圖像識別-102 花卉分類
非常適合機器學習初學者的實戰項目,代碼放在(https://github.com/xiaoxijio/Flower-classification) ,歡迎大家過來對我的代碼指指點點
花卉數據集
因為數據集太多了,上傳不了,所以大家可以自己去kaggle(https://www.kaggle.com/datasets/nunenuh/pytorch-challange-flower-dataset) 下載(下載挺快的,三百多M)
這是下載後包含的數據
102花卉數據集
我是放在'data/flower'目錄下,大家如果不放在這個目錄下更改一下我代碼裡文件位置就行。
那麼好!話歸正題,如何白嫖別人訓練好的模型
現在神經網絡發展那麼快,我們還自己從零開始敲摸著怎麼構建神經網絡那就大清亡了!垂涎欲滴地看著別人訓練好的模型,不如直接零幀起手,順手(紅色警告!做合法牛馬)。小小心思,Pytorch(https://pytorch.org/vision/stable/models.html) 早已看穿,為了促進人類文明的發展,與其..不如贈人玫瑰手留餘香。目前已有如下這些零元購:
分類模型
那麼好!開始我們的模型調教!
什麼導入數據,數據增強我就不說了,代碼裡註釋的很清楚,主要說一下如何調教
比如我想用resnet(https://pytorch.org/vision/stable/models/resnet.html) 網絡,我們去官網一看,官網提供瞭如下幾個
ResNet模型
那我們肯定不選對的,只選貴的。resnet152,看著就比別人牛逼,就是它了
1、調用前需要初始化一下
初始化模型
比如我們的分類任務時102種花卉,而別人的模型輸出是1000,所以我們首先要將別人模型的輸出層size改成我們自己的size。
當然,我們都用別人的東西,也要適配別人,不能純調教,那太喪心病狂了。比如resnet的架構在ImageNet上預訓練時,輸入尺寸就是224×224,那我們最好也將圖片尺寸調整為224×224(如果你偏不,也不是不行,真是拿你沒辦法呢)
2、凍結模型的特徵提取層,只訓練最後一層
能看這篇內容的,大都是發憤圖強的學生,學生哪有那麼好的gpu去訓練模型。像現在這些牛逼哄哄的神經網絡,人家頂級機器訓練都是按天計算的,更何況...嗚嗚嗚
既然我們要白嫖,那就把白嫖精神堅持到底!
凍結參數
機智靈敏的盒友們已經發現,這函數在上面初始化用到了。
我們把前面那些亂七八糟的需要調整的參數全給凍起來,就用人家訓練好的模型參數,給自己留個最後的全連接層訓練就行啦,我們就已經很努力啦。然後把需要訓練的全連接層參數給優化器,讓它去慢慢優化。
訓練過程和驗證過程我就不說了,人家要說的話都在酒裡,我要說的話都在代碼裡,詳細移步github。在訓練過程中,將訓練效果比較好的數據給保存下來哦,方便我們下次訓練的時候不用從零開始了(白嫖別人更要白嫖自己)
看看訓練效果
剛訓練準確率就達到了驚人的無地自容
訓練了20個epoch
我就跑了個20輪,沒多跑,大家電腦牛逼的可以多跑幾輪。大概跑了個二十幾分鍾吧,準確率從32% --> 73%
效果遠遠不夠啊,果然知識還得自己學才進腦子啊。但是不急,重新做人還來得及!
3、 加載預訓練好的模型,重新做人
我們將之前訓練效果好的模型數據加載出來,然後將之前投機取巧凍結的參數全部解凍,這一次,我要取回我的一切,電腦你要全力以赴啊!
在這一次訓練中,我們需要將學習率調的再低一點,讓模型在一個較優的情況下慢慢探索,不要一個步子邁大了,扯到不該扯到的
我們需要調整的代碼如下
feature_extract = False # 不凍結
load_model = True # 加載模型
optimizer_ft = optim.Adam(params_to_update, lr=1e-4) # 調小學習率
看看效果
這次我就跑了15輪,花費了大概半小時,效果一下提到了91%
好,我們已經學會如何使用別人的模型,那麼接下了你要去攻打... 沒錯,學會了1+1,你就要會巴拉巴拉(此處腦補非常複雜難懂反人類的數學問題)
話說讀書人的 怎麼能叫 這叫遷移學習!(遷移學習為我正名)