用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
原文地址:https://zhuanlan.zhihu.com/p/266311690
論文地址:https://arxiv.org/pdf/2010.11929.pdf
代碼地址:https://github.com/google-research/vision_transformer
用Transformer完全代替CNN
- 1. Story
- 2. Model
- a 將圖像轉化為序列化數據
- b Position embedding
- c Learnable embedding
- d Transformer encoder
- 3. 混合結構
- 4. Fine-tuning過程中高分辨率圖像的處理
- 5. 實驗
1. Story
近年來,Transformer已經成了NLP領域的標準配置,但是CV領域還是CNN(如ResNet, DenseNet等)占據了絕大多數的SOTA結果。
最近CV界也有很多文章將transformer遷移到CV領域,這些文章總的來說可以分為兩個大類:
- 將self-attention機制與常見的CNN架構結合;
- 用self-attention機制完全替代CNN。
本文采用的也是第2種思路。雖然已經有很多工作用self-attention完全替代CNN,且在理論上效率比較高,但是它們用了特殊的attention機制,無法從硬件層面加速,所以目前CV領域的SOTA結果還是被CNN架構所占據。
文章不同于以往工作的地方,就是盡可能地將NLP領域的transformer不作修改地搬到CV領域來。但是NLP處理的語言數據是序列化的,而CV中處理的圖像數據是三維的(長、寬和channels)。
所以我們需要一個方式將圖像這種三維數據轉化為序列化的數據。文章中,圖像被切割成一個個patch,這些patch按照一定的順序排列,就成了序列化的數據。(具體將在下面講述)
在實驗中,作者發現,在中等規模的數據集上(例如ImageNet),transformer模型的表現不如ResNets;而當數據集的規模擴大,transformer模型的效果接近或者超過了目前的一些SOTA結果。作者認為是大規模的訓練可以鼓勵transformer學到CNN結構所擁有的translation equivariance和locality.
2. Model
Vision Transformer (ViT)結構示意圖
模型的結構其實比較簡單,可以分成以下幾個部分來理解:
a 將圖像轉化為序列化數據
作者采用了了一個比較簡單的方式。如下圖所示。首先將圖像分割成一個個patch,然后將每個patch reshape成一個向量,得到所謂的flattened patch。
具體地,如果圖片是H×W×CH\times W\times CH×W×C維,用P×PP\times PP×P大小的patch去分割圖片可以得到N個patch,那么每個patch的shape就是P×P×CP\times P\times CP×P×C,轉化為向量就是P2CP^2CP2C維向量,將N個patch reshape后的向量concat在一起就得到了一個N×(P2C)N\times (P^2C)N×(P2C)的二維矩陣,相當于NLP中輸入transformer的詞向量。
- 分割圖像得到patch
從上面的過程可以看出,當patch的大小變化時(即 P 變化時),每個patch reshape后得到的 P2CP^2CP2C 維向量的長度也會變化。為了避免模型結構受到patch size的影響,作者對上述過程得到的flattened patches向量做了Linear Projection(如下圖所示),將不同長度的flattened patch向量轉化為固定長度的向量(記做D維向量)。
- 對flattened patches做linear projection
綜上,原本H×W×CH\times W\times CH×W×C維的圖片被轉化為N個D維的向量(或者一個N×DN\times DN×D維的二維矩陣)。
b Position embedding
- Position embedding
由于transformer模型本身是沒有位置信息的,和NLP中一樣,我們需要用position embedding將位置信息加到模型中去。
如上圖所示1,編號有0-9的紫色框表示各個位置的position embedding,而紫色框旁邊的粉色框則是經過linear projection之后的flattened patch向量。文中采用將position embedding(即圖中紫色框)和patch embedding(即圖中粉色框)相加的方式結合position信息。
c Learnable embedding
如果大家仔細看上圖,就會發現帶星號的粉色框(即0號紫色框右邊的那個)不是通過某個patch產生的。這個是一個learnable embedding(記作 XclassX_{class}Xclass? ),其作用類似于BERT中的[class] token。在BERT中,[class] token經過encoder后對應的結果作為整個句子的表示;類似地,這里 XclassX_{class}Xclass? 經過encoder后對應的結果也作為整個圖的表示。
至于為什么BERT或者這篇文章的ViT要多加一個token呢?因為如果人為地指定一個embedding(例如本文中某個patch經過Linear Projection得到的embedding)經過encoder得到的結果作為整體的表示,則不可避免地會使得整體表示偏向于這個指定embedding的信息(例如圖像的表示偏重于反映某個patch的信息)。而這個新增的token沒有語義信息(即在句子中與任何的詞無關,在圖像中與任何的patch無關),所以不會造成上述問題,能夠比較公允地反映全圖的信息。
d Transformer encoder
Transformer Encoder結構和NLP中transformer結構基本上相同,所以這里只給出其結構圖,和公式化的計算過程,也是順便用公式表達了之前所說的幾個部分內容。
Transformer Encoder的結構如下圖所示:
對于Encoder的第 lll 層,記其輸入為zl?1z_{l-1}zl?1?,輸出為zlz_lzl?,則計算過程為:
其中MSA為Multi-Head Self-Attention(即Transformer Encoder結構圖中的綠色框),MLP為Multi-Layer Perceptron(即Transformer Encoder結構圖中的藍色框),LN為Layer Norm(即Transformer Encoder結構圖中的黃色框)。
Encoder第一層的輸入z0z_0z0?是通過下面的公式得到的:
其中Xp1,...,XpNX_p^1,...,X_p^NXp1?,...,XpN?即未Linear Projection后的patch embedding(都是p2Cp^2Cp2C維)
3. 混合結構
文中還提出了一個比較有趣的解決方案,將transformer和CNN結合,即將ResNet的中間層的feature map作為transformer的輸入。
和之前所說的將圖片分成patch然后reshape成sequence不同的是,在這種方案中,作者直接將ResNet某一層的feature map reshape成sequence,再通過Linear Projection變為Transformer輸入的維度,然后直接輸入進Transformer中。
4. Fine-tuning過程中高分辨率圖像的處理
在Fine-tuning到下游任務時,當圖像的分辨率增大時(即圖像的長和寬增大時),如果保持patch大小不變,得到的patch個數將增加(記分辨率增大后新的patch個數為 N′N^{'}N′ )。但是由于在pretrain時,position embedding的個數和pretrain時分割得到的patch個數(即上文中的 N )相同。則多出來的 N′?NN^{'}-NN′?N 個positioin embedding在pretrain中是未定義或者無意義的。
為了解決這個問題,文章中提出用2D插值的方法,基于原圖中的位置信息,將pretrain中的 N 個position embedding插值成N′N^{'}N′ 個。這樣在得到 N′N^{'}N′ 個position embedding的同時也保證了position embedding的語義信息。
5. 實驗
實驗部分由于涉及到的細節較多就不具體介紹了,大家如果感興趣可以參看原文。(不得不說Google的實驗能力和鈔能力不是一般人能比的…)
主要的實驗結論在story中就已經介紹過了,這里復制粘貼一下:在中等規模的數據集上(例如ImageNet),transformer模型的表現不如ResNets;而當數據集的規模擴大,transformer模型的效果接近或者超過了目前的一些SOTA結果。
比較有趣的是,作者還做了很多其他的分析來解釋transfomer的合理性。大家如果感興趣也可以參看原文,這里放幾張文章中的圖。
總結
以上是生活随笔為你收集整理的用Transformer完全代替CNN:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【腾讯面试题】Nginx
- 下一篇: leetcode303 Range Su