论文阅读:Spatial Transformer Networks
文章目錄
- 1 概述
- 2 模型說明
- 2.1 Localisation Network
- 2.2 Parameterised Sampling Grid
- 3 模型效果
- 參考資料
1 概述
CNN的機理使得CNN在處理圖像時可以做到transition invariant,卻沒法做到scaling invariant和rotation invariant。即使是現(xiàn)在火熱的transformer搭建的圖像模型(swin transformer, vision transformer),也沒辦法做到這兩點。因為他們在處理時都會參考圖像中物體的相對大小和位置方向。不同大小和不同方向的物體,對網(wǎng)絡來說是不同的東西。這個問題在這篇文章統(tǒng)稱為spatially invariant問題。甚至不同方向的物體,本身就真的是不同的東西,比如文字。
其實pooling layer有一定程度上解決了這個問題,因為在做max pooling或者average pooling的時候,只要這個特征在,就可以提取出來,在什么位置,pooling layer是不關心的。但是pooling的kernel size通常都比較小,需要做到大物體的spatially invariant是很難的,除非網(wǎng)絡特別深。
STN(spatial transformer network)的提出,就是為了解決spatially invariant問題這個問題。它的主要思想很簡單,就是訓練一個可以把物體線性變換到模型正常的大小和方向的前置網(wǎng)絡。這個網(wǎng)絡可以前置于任何的圖像網(wǎng)絡中,即插即用。同時也可以和整個網(wǎng)絡一起訓練。
再說的直白一點,就是一個可以根據(jù)輸入圖片輸出仿射變換參數(shù)的網(wǎng)絡。
圖1-1是STN網(wǎng)絡的一個結果示意圖,圖1-1(a)是輸入圖片,圖1-1(b)是STN中的localisation網(wǎng)絡檢測到的物體區(qū)域,圖1-1?是STN對檢測到的區(qū)域進行線性變換后輸出,圖1-1(d)是有STN的分類網(wǎng)絡的最終輸出。
2 模型說明
STN(spatial transformer network)更準確地說應該是STL(spatial transformer layer),它就是網(wǎng)絡中的一層,并且可以在任何兩層之間添加一個或者多個。如下圖2-1所示,spatial transformer主要由兩部分組成,分別是localisation net和grid generator。
2.1 Localisation Network
我們的目的是把第l?1l-1l?1層的第nnn行,第mmm列的特征移動到第lll層的某行某列。如下圖2-2所示,一個3×33 \times 33×3的特征要變換的話,第lll層的每個位置都可以表示為l?1l-1l?1層的特征的加權和。通過控制權重wnm,ijlw_{nm,ij}^lwnm,ijl?就可以實現(xiàn)任何仿射變換。
但如果直接加一層全連接讓模型學的話,模型可能學出來的就不是仿射變換了,參數(shù)量也很大,很難學,很難控制。所有就設計了一個localisation net,直接讓模型學仿射變換的參數(shù),這相當于是一個inductive bias。
localisation net的輸入是前一層的特征,輸出是仿射變換的參數(shù),如果是平面的放射變換就是6個參數(shù),通過這六個參數(shù)可以控制整個圖像的平移,旋轉,縮放。
圖2-3中的[a,b,c,d,e,f][a,b,c,d,e,f][a,b,c,d,e,f]參數(shù)就是localisation net的輸出。核心公式就是
[x′y′]=[abecdf][xy1](2-1)\left [ \begin{matrix} x' \\ y' \end{matrix} \right ] = \left [ \begin{matrix} a & b & e \\ c & d & f \end{matrix} \right ] \left [ \begin{matrix} x \\ y \\ 1\end{matrix} \right ] \tag{2-1} [x′y′?]=[ac?bd?ef?]???xy1????(2-1)
其中,xxx和yyy是當前層的坐標,x′x'x′和y′y'y′是前一層的坐標,aaa和ddd主要控制縮放,bbb和ccc主要控制旋轉,eee和fff主要控制平移。
2.2 Parameterised Sampling Grid
localisation net輸出了仿射變換參數(shù)之后,式(2?1)(2-1)(2?1)告訴了我們當前層(x,y)(x,y)(x,y)這個位置的特征是前一層的(x′,y′)(x', y')(x′,y′)位置的特征拿過來的。但是,如圖2-4中的例子所示,(x′,y′)(x', y')(x′,y′)可能是小數(shù),位置需要是正整數(shù),如果采用取整的操作的話,網(wǎng)絡就會變得不可梯度下降,沒法更新參數(shù)了。
我們想要的是,當[a,b,c,d,e,f][a,b,c,d,e,f][a,b,c,d,e,f]發(fā)生微小的變化之后,下一層的特征也發(fā)生變化,這樣才可以保證可以梯度下降。
于是,作者就采用了插值的方法來進行采樣。比如當坐標為[1.6,2.4][1.6, 2.4][1.6,2.4]時,就用[a12l?1,a13l?1,a22l?1,a23l?1][a_{12}^{l-1}, a_{13}^{l-1}, a_{22}^{l-1}, a_{23}^{l-1}][a12l?1?,a13l?1?,a22l?1?,a23l?1?]這幾個值進行插值。這樣一來[a,b,c,d,e,f][a,b,c,d,e,f][a,b,c,d,e,f]發(fā)生微小的變化之后,[x,y][x,y][x,y]位置采樣得到的值也會有變化了。這也使得spatial transformer可以放到任何層,跟整個網(wǎng)絡一起訓練。
3 模型效果
(1)文字識別任務
圖3-1表示了STN加入到文字識別任務時帶來的效果提升,圖3-1左側是不同模型在SVHN數(shù)據(jù)集上的錯誤率。這相當于是一個只有數(shù)字的文字識別任務。
不過用在OCR任務中,STN其實有點雞肋。OCR有文字檢測和文字識別兩個部分,一般文字檢測部分會帶有檢測框四個頂點的坐標,我們直接把文字檢測的結果進行仿射變換再去文字識別就可以了,不需要在文字識別時再加一個STN。一個是難訓練,再就是推理性能下降。
(2)圖像分類任務
圖3-2表示了STN在圖像識別任務的效果,圖3-2左側表示了不同模型在Caltech-UCSD Birds-200-2011數(shù)據(jù)集上的準確率。
2×ST?CNN2 \times ST-CNN2×ST?CNN表示在同一層是用了兩個不同的STL,4×ST?CNN4 \times ST-CNN4×ST?CNN表示在同一層是用了四個不同的STL。圖3-2右側中的方框表示了不同STL要進行放射變換的位置。
可以看到不同的STL關注的鳥的部位也是不一樣的,一個一直關注頭部,一個一直關注身子。這就相當于是一個attention,把感興趣的區(qū)域提取出來了。
還有一個地方是,這里的方框都是正的,這其實是因為作者把仿射變換中的參數(shù)[b,c][b, c][b,c]人為置0了,變成了
[x′y′]=[a0e0df][xy1](3-1)\left [ \begin{matrix} x' \\ y' \end{matrix} \right ] = \left [ \begin{matrix} a & 0 & e \\ 0 & d & f \end{matrix} \right ] \left [ \begin{matrix} x \\ y \\ 1\end{matrix} \right ] \tag{3-1} [x′y′?]=[a0?0d?ef?]???xy1????(3-1)
可見STN是一個可以融入到很多圖像模型,且可拓展性高的模塊。
參考資料
[1] Spatial Transformer Networks
[2] 李宏毅-Spatial Transformer Layer
總結
以上是生活随笔為你收集整理的论文阅读:Spatial Transformer Networks的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 五、pink老师的学习笔记——CSS精灵
- 下一篇: Node 中的path模块