1. Introduction
In deep learning competitions, image classification is one of the most common tasks, and also one in which it is hard to place especially high: image classification has been studied very thoroughly, and open-source networks make high scores easy to reach. If you cannot yet train with an open-source network and then gradually tune the model, it is difficult to get a good result.
In [PyTorch Hands-On] Practice 6: Preparing Your Own Dataset for Training, we explained how to build your own dataset for training. This tutorial builds on that work and walks through training and applying a model.
(Link to Practice 6:
https://blog.csdn.net/xiaosongshine/article/details/85225873)
2. The Dataset
Data download link:
https://download.csdn.net/download/xiaosongshine/11128410
This exercise uses a traffic-sign dataset with 62 classes of traffic signs. The training set contains 4572 images (roughly 70 per class) and the test set contains 2520 images (roughly 40 per class). The data is split into two subdirectories, train and test:
Why do we need a test set at all? The test set is never used for training; it is used to evaluate and tune the model.
Both train and test contain 62 subfolders, with each class kept in its own folder:
Opening one of these folders shows the images inside:
Each image looks like the example above and has size 100*100*3. The 100s are the image height and width; what is the 3? It is the number of color channels, RGB. A color image is stored in the computer as a three-dimensional array, and it is these arrays that we feed into the network.
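If you want to check this yourself, a minimal sketch like the one below will do. The class folder name 00000 is only an assumption about how the 62 subfolders are named; point it at any folder from your own download.

import os
import numpy as np
from PIL import Image

sample_dir = "./traffic-sign/train/00000"      # hypothetical class folder name
sample_file = os.listdir(sample_dir)[0]        # take the first image in that folder
img = Image.open(os.path.join(sample_dir, sample_file)).convert("RGB")
arr = np.array(img)                            # height x width x channels, uint8
print(arr.shape, arr.dtype)                    # expected: (100, 100, 3) uint8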
3. Building the Network
1. Imports and basic configuration:

import torch as t
import torchvision as tv
import os
import time
import numpy as np
from tqdm import tqdm


class DefaultConfigs(object):

    data_dir = "./traffic-sign/"
    data_list = ["train", "test"]

    lr = 0.001
    epochs = 10
    num_classes = 62
    image_size = 224
    batch_size = 40
    channels = 3
    gpu = "0"
    train_len = 4572
    test_len = 2520
    use_gpu = t.cuda.is_available()

config = DefaultConfigs()
2. Data preparation, using the loading utilities PyTorch provides.
Note that the training data should be randomly cropped for augmentation, while the test data should not be cropped.
normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

transform = {
    config.data_list[0]: tv.transforms.Compose(
        [tv.transforms.Resize([224, 224]), tv.transforms.CenterCrop([224, 224]),
         tv.transforms.ToTensor(), normalize]  # tv.transforms.Resize resizes the image
    ),
    config.data_list[1]: tv.transforms.Compose(
        [tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(), normalize]
    )
}

datasets = {
    x: tv.datasets.ImageFolder(root=os.path.join(config.data_dir, x), transform=transform[x])
    for x in config.data_list
}

dataloader = {
    x: t.utils.data.DataLoader(dataset=datasets[x],
                               batch_size=config.batch_size,
                               shuffle=True)
    for x in config.data_list
}
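As written, the training pipeline resizes to 224x224 and then takes a 224x224 center crop, which changes nothing. If you want the random cropping mentioned above, one possible variant is sketched below (my own suggestion, not the original code): resize slightly larger, then crop randomly. Define it before building the datasets/dataloaders so it is actually used. Horizontal flips are left out on purpose, since mirroring can change a traffic sign's meaning.

# Sketch of a randomly cropped training transform (replaces the train entry above).
transform[config.data_list[0]] = tv.transforms.Compose([
    tv.transforms.Resize([256, 256]),        # resize a bit larger than the crop
    tv.transforms.RandomCrop([224, 224]),    # random 224x224 crop for augmentation
    tv.transforms.ToTensor(),
    normalize
])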
3. Build the network model (transfer learning with resnet18; only the final fully connected layer, t.nn.Linear(512, num_classes), is trained).
def get_model(num_classes):

    model = tv.models.resnet18(pretrained=True)
    for param in model.parameters():
        param.requires_grad = False
    model.fc = t.nn.Sequential(
        t.nn.Dropout(p=0.3),
        t.nn.Linear(512, num_classes)
    )
    return model
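An optional sanity check (a sketch): after freezing the backbone, only the two parameters of the new fc head should still require gradients.

# Sketch: list the parameters that will actually be trained.
model = get_model(config.num_classes)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)   # expected: ['fc.1.weight', 'fc.1.bias']
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # 512*62 + 62 = 31806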
If your hardware can handle it, you can comment out the code below so that the whole network is trained; the final accuracy will be higher, but training will be slower.
for param in model.parameters():
    param.requires_grad = False
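One caveat if you do unfreeze everything: the optimizer in the training code below is built from model.fc.parameters() only, so the backbone would still not be updated. A hedged sketch of full fine-tuning, replacing that optimizer line, is shown here; giving the pretrained layers a smaller learning rate is a common choice of mine, not something the original code prescribes.

# Sketch: optimize all parameters, with a gentler learning rate for the pretrained backbone.
backbone_params = [p for name, p in model.named_parameters() if not name.startswith("fc")]
opt = t.optim.Adam([
    {"params": backbone_params, "lr": config.lr * 0.1},   # pretrained layers
    {"params": model.fc.parameters(), "lr": config.lr},   # new classification head
])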
Printing the model gives:
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Sequential(
    (0): Dropout(p=0.3)
    (1): Linear(in_features=512, out_features=62, bias=True)
  )
)
4. Train the model (GPU acceleration is enabled automatically when available).
def train(epochs):

    model = get_model(config.num_classes)
    print(model)
    loss_f = t.nn.CrossEntropyLoss()
    if config.use_gpu:
        model = model.cuda()
        loss_f = loss_f.cuda()

    opt = t.optim.Adam(model.fc.parameters(), lr=config.lr)
    time_start = time.time()

    for epoch in range(epochs):
        train_loss = []
        train_acc = []
        test_loss = []
        test_acc = []
        model.train(True)
        print("Epoch {}/{}".format(epoch + 1, epochs))
        for batch, datas in tqdm(enumerate(iter(dataloader["train"]))):
            x, y = datas
            if config.use_gpu:
                x, y = x.cuda(), y.cuda()
            y_ = model(x)
            # print(x.shape, y.shape, y_.shape)
            _, pre_y_ = t.max(y_, 1)
            pre_y = y
            # print(y_.shape)
            loss = loss_f(y_, pre_y)
            # print(y_.shape)
            acc = t.sum(pre_y_ == pre_y)

            loss.backward()
            opt.step()
            opt.zero_grad()
            if config.use_gpu:
                loss = loss.cpu()
                acc = acc.cpu()
            train_loss.append(loss.data)
            train_acc.append(acc)
        # if (batch + 1) % 5 == 0:
        time_end = time.time()
        print("Batch {}, Train loss: {:.4f}, Train acc: {:.4f}, Time: {}".format(
            batch + 1, np.mean(train_loss) / config.batch_size,
            np.mean(train_acc) / config.batch_size, (time_end - time_start)))
        time_start = time.time()

        model.train(False)
        for batch, datas in tqdm(enumerate(iter(dataloader["test"]))):
            x, y = datas
            if config.use_gpu:
                x, y = x.cuda(), y.cuda()
            y_ = model(x)
            # print(x.shape, y.shape, y_.shape)
            _, pre_y_ = t.max(y_, 1)
            pre_y = y
            # print(y_.shape)
            loss = loss_f(y_, pre_y)
            acc = t.sum(pre_y_ == pre_y)

            if config.use_gpu:
                loss = loss.cpu()
                acc = acc.cpu()

            test_loss.append(loss.data)
            test_acc.append(acc)
        print("Batch {}, Test loss: {:.4f}, Test acc: {:.4f}".format(
            batch + 1, np.mean(test_loss) / config.batch_size,
            np.mean(test_acc) / config.batch_size))

        t.save(model, str(epoch + 1) + "ttmodel.pkl")


if __name__ == "__main__":
    train(config.epochs)
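One detail worth noting about the evaluation loop: it still tracks gradients, which costs memory and time. A possible drop-in replacement (a sketch that reuses the names defined above) wraps it in t.no_grad() and calls model.eval(), which does the same thing as model.train(False):

# Sketch: evaluation without gradient tracking.
model.eval()
with t.no_grad():
    for batch, datas in tqdm(enumerate(iter(dataloader["test"]))):
        x, y = datas
        if config.use_gpu:
            x, y = x.cuda(), y.cuda()
        y_ = model(x)                     # forward pass only, no graph is built
        _, pre_y_ = t.max(y_, 1)
        loss = loss_f(y_, y)
        acc = t.sum(pre_y_ == y)
        test_loss.append(loss.cpu().data)
        test_acc.append(acc.cpu())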
The training log is as follows:
Epoch 1/10
115it [00:48, 2.63it/s]
Batch 115, Train loss: 0.0590, Train acc: 0.4635, Time: 48.985504150390625
63it [00:24, 2.62it/s]
Batch 63, Test loss: 0.0374, Test acc: 0.6790, Time: 24.648272275924683
Epoch 2/10
115it [00:45, 3.22it/s]
Batch 115, Train loss: 0.0271, Train acc: 0.7576, Time: 45.68823838233948
63it [00:23, 2.62it/s]
Batch 63, Test loss: 0.0255, Test acc: 0.7524, Time: 23.271782875061035
Epoch 3/10
115it [00:45, 3.19it/s]
Batch 115, Train loss: 0.0181, Train acc: 0.8300, Time: 45.92648506164551
63it [00:23, 2.60it/s]
Batch 63, Test loss: 0.0212, Test acc: 0.7861, Time: 23.80789279937744
Epoch 4/10
115it [00:45, 3.28it/s]
Batch 115, Train loss: 0.0138, Train acc: 0.8767, Time: 45.27525019645691
63it [00:23, 2.57it/s]
Batch 63, Test loss: 0.0173, Test acc: 0.8385, Time: 23.736321449279785
Epoch 5/10
115it [00:44, 3.22it/s]
Batch 115, Train loss: 0.0112, Train acc: 0.8950, Time: 44.983638286590576
63it [00:22, 2.69it/s]
Batch 63, Test loss: 0.0156, Test acc: 0.8520, Time: 22.790074348449707
Epoch 6/10
115it [00:44, 3.19it/s]
Batch 115, Train loss: 0.0095, Train acc: 0.9159, Time: 45.10426950454712
63it [00:22, 2.77it/s]
Batch 63, Test loss: 0.0158, Test acc: 0.8214, Time: 22.80412459373474
Epoch 7/10
115it [00:45, 2.95it/s]
Batch 115, Train loss: 0.0081, Train acc: 0.9280, Time: 45.30439043045044
63it [00:23, 2.66it/s]
Batch 63, Test loss: 0.0139, Test acc: 0.8528, Time: 23.122379541397095
Epoch 8/10
115it [00:44, 3.23it/s]
Batch 115, Train loss: 0.0073, Train acc: 0.9300, Time: 44.304762840270996
63it [00:22, 2.74it/s]
Batch 63, Test loss: 0.0142, Test acc: 0.8496, Time: 22.801835536956787
Epoch 9/10
115it [00:43, 3.19it/s]
Batch 115, Train loss: 0.0068, Train acc: 0.9361, Time: 44.08414030075073
63it [00:23, 2.44it/s]
Batch 63, Test loss: 0.0142, Test acc: 0.8437, Time: 23.604419231414795
Epoch 10/10
115it [00:46, 3.12it/s]
Batch 115, Train loss: 0.0063, Train acc: 0.9337, Time: 46.76597046852112
63it [00:24, 2.65it/s]
Batch 63, Test loss: 0.0130, Test acc: 0.8591, Time: 24.64351773262024
After 10 epochs of training, the test-set accuracy reaches about 0.86, which is already a decent result. With further parameter tuning and more training, higher accuracy is achievable.
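The training loop above saves the full model after every epoch with t.save(model, ...). For the application side, a minimal inference sketch could look like the following; the checkpoint name follows the save pattern above, while the input image path is purely hypothetical.

from PIL import Image

# Sketch: load a saved checkpoint and classify a single image.
model = t.load("10ttmodel.pkl", map_location="cpu")   # checkpoint written after epoch 10
model.eval()

img = Image.open("my_sign.png").convert("RGB")        # hypothetical input image
x = transform[config.data_list[1]](img).unsqueeze(0)  # reuse the test-time transform, add batch dim
with t.no_grad():
    pred = t.argmax(model(x), dim=1).item()

# ImageFolder stores the class (folder) names; map the index back to a class name.
print("predicted class:", datasets["test"].classes[pred])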
Original title: 實戰:掌握PyTorch圖片分類的簡明教程 | 附完整代碼 (Hands-on: a concise tutorial on PyTorch image classification, with full code)
Source: WeChat official account rgznai100. Please credit the source when reposting.