Introduction to torchvision classification
Recent versions of torchvision ship a variety of SOTA image classification models and make it easy to switch between pretrained weights trained on different datasets (a sketch of this newer weights API follows the list below). They are very convenient to use. PyTorch supports two transfer learning approaches:
- Finetune mode: start from the pretrained model and tune all parameters end to end.
- Frozen feature extractor mode: only the parameters of the output layer are updated; the CNN backbone is frozen.

The two approaches suit large and small datasets respectively. The first takes more compute and training time than the second, but it tends to give better results on large custom classification datasets.
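As a minimal sketch (not from the original article), this is roughly how torchvision 0.13+ exposes the pretrained-weight selection mentioned above; older releases use the pretrained=True flag shown later in this article:

import torchvision.models as models

# Explicit weight enum: ImageNet-1K weights for ResNet-18
weights = models.ResNet18_Weights.IMAGENET1K_V1
model = models.resnet18(weights=weights)

# The weight enum also carries the matching preprocessing transforms
preprocess = weights.transforms()

# Since torchvision 0.14, models can also be looked up by name:
# model = models.get_model("resnet18", weights="IMAGENET1K_V1")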
Modifying and training a custom classification model
After loading the model, setting feature_extracting to True selects the frozen mode; otherwise the model is trained in finetune mode. The relevant code is as follows:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

Taking resnet18 as an example, the modified custom training code looks like this:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 5.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 5)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)
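The train_model helper called above is not listed in the article. A minimal sketch in the spirit of the official PyTorch transfer learning tutorial might look like this, assuming dataloaders and dataset_sizes dictionaries keyed by 'train' and 'valid' (names chosen here for illustration) and a global device:

import copy
import torch

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    # Track the best weights observed on the validation split
    best_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss, running_corrects = 0.0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Only build the autograd graph in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (outputs.argmax(1) == labels).sum().item()
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f'Epoch {epoch}/{num_epochs - 1} {phase} '
                  f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Remember the best validation accuracy and its weights
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_wts = copy.deepcopy(model.state_dict())
    # Restore the best checkpoint before returning
    model.load_state_dict(best_wts)
    return model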
The dataset is the flowers dataset, with five classes:
daisy dandelion roses sunflowers tulips
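The article does not show the data-loading code. A minimal sketch, assuming the flowers images are laid out as one sub-folder per class under hypothetical train/ and valid/ directories:

import torch
from torchvision import datasets, transforms

# Standard ImageNet normalization, matching the pretrained ResNet-18 backbone
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# One folder per class (daisy, dandelion, roses, sunflowers, tulips)
image_datasets = {x: datasets.ImageFolder(f'flowers-dataset/{x}', data_transforms[x])
                  for x in ['train', 'valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')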
Full finetuning: transfer learning that also updates the weights of the CNN backbone
Epoch 0/24    train Loss: 1.3993 Acc: 0.5597    valid Loss: 1.8571 Acc: 0.7073
Epoch 1/24    train Loss: 1.0903 Acc: 0.6580    valid Loss: 0.6150 Acc: 0.7805
Epoch 2/24    train Loss: 0.9095 Acc: 0.6991    valid Loss: 0.4386 Acc: 0.8049
Epoch 3/24    train Loss: 0.7628 Acc: 0.7349    valid Loss: 0.9111 Acc: 0.7317
Epoch 4/24    train Loss: 0.7107 Acc: 0.7669    valid Loss: 0.4854 Acc: 0.8049
Epoch 5/24    train Loss: 0.6231 Acc: 0.7793    valid Loss: 0.6822 Acc: 0.8049
Epoch 6/24    train Loss: 0.5768 Acc: 0.8033    valid Loss: 0.2748 Acc: 0.8780
Epoch 7/24    train Loss: 0.5448 Acc: 0.8110    valid Loss: 0.4440 Acc: 0.7561
Epoch 8/24    train Loss: 0.5037 Acc: 0.8170    valid Loss: 0.2900 Acc: 0.9268
Epoch 9/24    train Loss: 0.4836 Acc: 0.8360    valid Loss: 0.7108 Acc: 0.7805
Epoch 10/24   train Loss: 0.4663 Acc: 0.8369    valid Loss: 0.5868 Acc: 0.8049
Epoch 11/24   train Loss: 0.4276 Acc: 0.8504    valid Loss: 0.6998 Acc: 0.8293
Epoch 12/24   train Loss: 0.4299 Acc: 0.8529    valid Loss: 0.6449 Acc: 0.8049
Epoch 13/24   train Loss: 0.4256 Acc: 0.8567    valid Loss: 0.7897 Acc: 0.7805
Epoch 14/24   train Loss: 0.4062 Acc: 0.8559    valid Loss: 0.5855 Acc: 0.8293
Epoch 15/24   train Loss: 0.4030 Acc: 0.8545    valid Loss: 0.7336 Acc: 0.7805
Epoch 16/24   train Loss: 0.3786 Acc: 0.8730    valid Loss: 1.0429 Acc: 0.7561
Epoch 17/24   train Loss: 0.3699 Acc: 0.8763    valid Loss: 0.4549 Acc: 0.8293
Epoch 18/24   train Loss: 0.3394 Acc: 0.8788    valid Loss: 0.2828 Acc: 0.9024
Epoch 19/24   train Loss: 0.3300 Acc: 0.8834    valid Loss: 0.6766 Acc: 0.8537
Epoch 20/24   train Loss: 0.3136 Acc: 0.8906    valid Loss: 0.5893 Acc: 0.8537
Epoch 21/24   train Loss: 0.3110 Acc: 0.8901    valid Loss: 0.4909 Acc: 0.8537
Epoch 22/24   train Loss: 0.3141 Acc: 0.8931    valid Loss: 0.3930 Acc: 0.9024
Epoch 23/24   train Loss: 0.3106 Acc: 0.8887    valid Loss: 0.3079 Acc: 0.9024
Epoch 24/24   train Loss: 0.3143 Acc: 0.8923    valid Loss: 0.5122 Acc: 0.8049
Training complete in 25m 34s
Best val Acc: 0.926829
Freezing the CNN backbone: training only the fully connected classifier weights
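The setup code for this mode is not listed in the article. A minimal sketch, combining set_parameter_requires_grad from above with an optimizer that only receives the new fc parameters (which matches the "Params to learn: fc.weight, fc.bias" output in the log below); variable names here are illustrative:

import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models

model_conv = models.resnet18(pretrained=True)
set_parameter_requires_grad(model_conv, feature_extracting=True)  # freeze the CNN backbone

# The newly created fc layer has requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 5)
model_conv = model_conv.to(device)

# Collect and report only the parameters that will be updated
params_to_update = [p for p in model_conv.parameters() if p.requires_grad]
print("Params to learn:")
for name, p in model_conv.named_parameters():
    if p.requires_grad:
        print(" ", name)

criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler,
                         num_epochs=25)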
Params to learn:
 fc.weight
 fc.bias

Epoch 0/24    train Loss: 1.0217 Acc: 0.6465    valid Loss: 1.5317 Acc: 0.8049
Epoch 1/24    train Loss: 0.9569 Acc: 0.6947    valid Loss: 1.2450 Acc: 0.6829
Epoch 2/24    train Loss: 1.0280 Acc: 0.6999    valid Loss: 1.5677 Acc: 0.7805
Epoch 3/24    train Loss: 0.8344 Acc: 0.7426    valid Loss: 1.1053 Acc: 0.7317
Epoch 4/24    train Loss: 0.9110 Acc: 0.7250    valid Loss: 1.1148 Acc: 0.7561
Epoch 5/24    train Loss: 0.9049 Acc: 0.7346    valid Loss: 1.1541 Acc: 0.6341
Epoch 6/24    train Loss: 0.8538 Acc: 0.7465    valid Loss: 1.4098 Acc: 0.8293
Epoch 7/24    train Loss: 0.9041 Acc: 0.7349    valid Loss: 0.9604 Acc: 0.7561
Epoch 8/24    train Loss: 0.8885 Acc: 0.7468    valid Loss: 1.2603 Acc: 0.7561
Epoch 9/24    train Loss: 0.9257 Acc: 0.7333    valid Loss: 1.0751 Acc: 0.7561
Epoch 10/24   train Loss: 0.8637 Acc: 0.7492    valid Loss: 0.9748 Acc: 0.7317
Epoch 11/24   train Loss: 0.8686 Acc: 0.7517    valid Loss: 1.0194 Acc: 0.8049
Epoch 12/24   train Loss: 0.8492 Acc: 0.7572    valid Loss: 1.0378 Acc: 0.7317
Epoch 13/24   train Loss: 0.8773 Acc: 0.7432    valid Loss: 0.7224 Acc: 0.8049
Epoch 14/24   train Loss: 0.8919 Acc: 0.7473    valid Loss: 1.3564 Acc: 0.7805
Epoch 15/24   train Loss: 0.8634 Acc: 0.7490    valid Loss: 0.7822 Acc: 0.7805
Epoch 16/24   train Loss: 0.8069 Acc: 0.7644    valid Loss: 1.4132 Acc: 0.7561
Epoch 17/24   train Loss: 0.8589 Acc: 0.7492    valid Loss: 0.9812 Acc: 0.8049
Epoch 18/24   train Loss: 0.7677 Acc: 0.7688    valid Loss: 0.7176 Acc: 0.8293
Epoch 19/24   train Loss: 0.8044 Acc: 0.7514    valid Loss: 1.4486 Acc: 0.7561
Epoch 20/24   train Loss: 0.7916 Acc: 0.7564    valid Loss: 1.0575 Acc: 0.8049
Epoch 21/24   train Loss: 0.7922 Acc: 0.7647    valid Loss: 1.0406 Acc: 0.7805
Epoch 22/24   train Loss: 0.8187 Acc: 0.7647    valid Loss: 1.0965 Acc: 0.7561
Epoch 23/24   train Loss: 0.8443 Acc: 0.7503    valid Loss: 1.6163 Acc: 0.7317
Epoch 24/24   train Loss: 0.8165 Acc: 0.7583    valid Loss: 1.1680 Acc: 0.7561
Training complete in 20m 7s
Best val Acc: 0.829268
Test results:
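The test results in the original are shown as a screenshot. A minimal inference sketch, reusing the hypothetical data_transforms and device from the data-loading sketch above and a hypothetical test image path:

from PIL import Image
import torch

class_names = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

model_ft.eval()
img = Image.open('test_images/sunflower.jpg').convert('RGB')   # hypothetical path
x = data_transforms['valid'](img).unsqueeze(0).to(device)      # same preprocessing as validation

with torch.no_grad():
    probs = torch.softmax(model_ft(x), dim=1)
    conf, idx = probs.max(dim=1)
print(f'predicted: {class_names[idx.item()]} ({conf.item():.3f})')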
Zero-code training demo
I have wrapped the torchvision classification transfer learning code for custom datasets into a tool that supports zero-code training: point it at a collected dataset and it produces a trained model, as illustrated below:
Original title: torchvision easily supports transfer learning for ten image classification models
Source: WeChat public account OpenCV學堂 (ID: CVSCHOOL). Please credit the source when reposting.