編者按:Udacity深度強化學習課程負責人Alexis Cook講解了全局平均池化(GAP)的概念,并演示了為分類問題訓練的GAP-CNN在目標定位方面的能力。
圖像分類任務中,卷積神經網絡(CNN)架構的常見選擇是重復的卷積模塊(卷積層加池化層),之后是兩層以上的密集層(全連接層)。最后密集層使用softmax激活函數,每個節點對應一個類別。
比如,VGG-16的架構:
譯者注:上圖中,黑色的是卷積層(ReLU激活),紅色的是最大池化層,藍色的是全連接層(ReLU激活),金色的是softmax層。
運行以下代碼,可以得到VGG-16模型的網絡層清單:(譯者注:需要安裝Keras)
python -c 'from keras.applications.vgg16 import VGG16; VGG16().summary()'
輸出為:
你會注意到有5個卷積模塊(兩到三個卷積層,之后是一個最大池化層)。接著,扁平化最后一個最大池化層,后面跟著三個密集層。注意模型的大部分參數屬于全連接層!
你大概可以想見,這樣的架構有過擬合訓練數據集的風險。實踐中會使用dropout層以避免過擬合。
全局平均池化
最近幾年,人們開始使用全局平均池化(global average pooling,GAP)層,通過降低模型的參數數量來最小化過擬合效應。類似最大池化層,GAP層可以用來降低三維張量的空間維度。然而,GAP層的降維更加激進,一個h × w × d的張量會被降維至1 × 1 × d。GAP層通過取平均值映射每個h × w的特征映射至單個數字。
在最早提出GAP層的網中網(Network in Network)架構中,最后的最大池化層的輸出傳入GAP層,GAP層生成一個向量,向量的每一項表示分類任務中的一個類別。接著應用softmax激活函數生成每個分類的預測概率。如果你打算參考原論文(arXiv:1312.4400),我特別建議你看下3.2節“全局平均池化”。
ResNet-50模型沒這么激進;并沒有完全移除密集層,而是在GAP層之后加上一個帶softmax激活函數的密集層,生成預測分類。
目標定位
2016年年中,MIT的研究人員展示了為分類任務訓練的包含GAP層的CNN(GAP-CNN),同樣可以用于目標定位。也就是說,GAP-CNN不僅告訴我們圖像中包含的目標是什么東西,它還可以告訴我們目標在圖像中的什么地方,而且我們不需要額外為此做什么!定位表示為熱圖(分類激活映射),其中的色彩編碼方案標明了GAP-CNN進行目標識別任務相對重要的區域。
我根據Bolei Zhou等的論文(arXiv:1512.04150)探索了預訓練的ResNet-50模型的定位能力(代碼見GitHub:alexisbcook/ResNetCAM-keras)。主要的思路是GAP層之前的最后一層的每個激活映射起到了解碼圖像中的不同位置的模式的作用。我們只需將這些檢測到的模式轉換為檢測到的目標,就可以得到每張圖像的分類激活映射。
GAP層中的每個節點對應不同的激活映射,連接GAP層和最后的密集層的權重編碼了每個激活映射對預測目標分類的貢獻。將激活映射中的每個檢測到的模式的貢獻(對預測目標分類更重要的檢測到的模式獲得更多權重)累加起來,就得到了分類激活映射。
代碼如何運作
運行以下代碼檢視ResNet-50的架構:
python -c 'from keras.applications.resnet50 import ResNet50; ResNet50().summary()'
輸出如下:
注意,和VGG-16模型不同,并非大部分可訓練參數都位于網絡最頂上的全連接層中。
網絡最后的Activation、AveragePooling2D、Dense層是我們最感興趣的(上圖高亮部分)。實際上AveragePooling2D層是一個GAP層!
我們從Activation層開始。這一層包含2048個7 × 7維的激活映射。讓我們用fk表示第k個激活映射,其中k ∈{1,…,2048}。
接下來的AceragePooling2D層,也就是GAP層,通過取每個激活映射的平均值,將前一層的輸出大小降至(1,1,2048)。接下來的Flatten層只不過是扁平化輸入,沒有導致之前GAP層中包含信息的任何變動。
ResNet-50預測的每個目標類別對應最終的Dense層的每個節點,并且每個節點都和之前的Flatten層的各個節點相連。讓我們用wk表示連接Flatten層的第k個節點和對應預測圖像類別的輸出節點的權重。
接著,為了得到分類激活映射,我們只需計算:
我們可以將這些分類激活映射繪制在選定的圖像上,以探索ResNet-50的定位能力。為了便于和原圖比較,我們應用了雙線性上采樣,將激活映射的大小變為224 × 224.
如果你想在你自己的目標定位問題上應用這些代碼,可以訪問GitHub:https://github.com/alexisbcook/ResNetCAM-keras
-
神經網絡
+關注
關注
42文章
4765瀏覽量
100568 -
圖像分類
+關注
關注
0文章
90瀏覽量
11907 -
強化學習
+關注
關注
4文章
266瀏覽量
11220
原文標題:用于目標定位的全局平均池化
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論