论文阅读 - Video Swin Transformer
文章目錄
- 1 概述
- 2 模型介紹
- 2.1 整體架構(gòu)
- 2.1.1 backbone
- 2.1.2 head
- 2.2 模塊詳述
- 2.2.1 Patch Partition
- 2.2.2 3D Patch Merging
- 2.2.3 W-MSA
- 2.2.4 SW-MSA
- 2.2.5 Relative Position Bias
- 3 模型效果
- 參考資料
1 概述
Vision Transformer是transformer應(yīng)用到圖像領(lǐng)域的一個里程碑,它將CNN完全剔除,只使用了transformer來完成網(wǎng)絡(luò)的搭建,并且在圖像分類任務(wù)中取得了state-of-art的效果。
Swin Transformer則更進(jìn)一步,引入了一些inductive biases,將CNN的結(jié)構(gòu)和transformer結(jié)合在了一起,使得transformer在圖像全領(lǐng)域都取得了state of art的效果。Swin Transformer中也有用到CNN,但是并不是把CNN當(dāng)做CNN來用的,只是用CNN的模塊來寫代碼比較方便。所以,也可以認(rèn)為是完全沒有使用CNN。
網(wǎng)上關(guān)于Swin Transformer的解讀多的不得了,這里來說說Swin Transformer在視頻領(lǐng)域的應(yīng)用,也就是Video Swin Transformer。如果非常熟悉Swin Transformer的話,那這篇文章就非常容易讀懂了,只是多了一個時間的維度,做attention和構(gòu)建window的時候略有區(qū)別。本文的參考資料也大多是Swin Transformer的。
這篇文章會從視頻的角度來解讀Swin Transformer。
2 模型介紹
2.1 整體架構(gòu)
2.1.1 backbone
Video Swin Transformer的backbone的整體架構(gòu)和Swin Transformer大同小異,多了一個時間維度TTT,在做Patch Partition的時候會有個時間維度的patch size。
以圖2-1為例,輸入為一個尺寸為T×H×W×3T \times H \times W \times 3T×H×W×3的視頻,通常還會有個batch size,這里省略掉了。TTT一般設(shè)置為32,表示從視頻的所有幀中采樣得到32幀,采樣的方法可以自行選擇,不同任務(wù)可能會有不同的采樣方法,一般為等間隔采樣。這里其實(shí)也就天然限制了模型不能處理和訓(xùn)練數(shù)據(jù)時長相差太多的視頻。通常視頻分類任務(wù)的視頻會在10s左右,太長的視頻也很難分到一個類別里。
輸入經(jīng)過Patch Partition之后會變成一個T2×H4×W4×96\frac{T}{2} \times \frac{H}{4} \times \frac{W}{4} \times 962T?×4H?×4W?×96的向量。這是因?yàn)閜atch size在這里為(2,4,4)(2,4,4)(2,4,4),分別是時間,高度和寬度三個維度的尺寸,其中969696是因?yàn)?span id="ze8trgl8bvbq" class="katex--inline">2×4×4×3=962 \times 4 \times 4 \times 3 = 962×4×4×3=96,也就是一個patch內(nèi)的所有像素點(diǎn)的rgb三個通道的值。Patch Partition會在2.2中詳述。
Patch Partiion之后會緊跟一個Linear Embedding,這兩個模塊在代碼中是寫在一起的,可以參見PatchEmbed3D,就是直接用一個3D的卷積,用這個卷積充當(dāng)全連接。如果embedding的dim為96,那么經(jīng)過embedding之后的尺寸還是2×4×4×3=962 \times 4 \times 4 \times 3 = 962×4×4×3=96。
之后分別會經(jīng)過多個video swin transformer block和patch merging。video swin transformer是利用attention同一個window內(nèi)的特征進(jìn)行特征融合的模塊;patch merging則是用來改變特征的shape,可以當(dāng)作CNN模型當(dāng)中的pooling,不過規(guī)則不同,而且patch merging還會改變特征的dim,也就是CCC改變。整個過程模仿了CNN模塊中的下采樣過程,這也是為了讓模型可以針對不同尺度生成特征。淺層可以看到小物體,深層則著重關(guān)注大物體。
video swin transformer block的結(jié)構(gòu)如下圖2-2所示。
圖2-2 video swin transformer block結(jié)構(gòu)圖2-2的左和右是兩個不同的blocks,需要連在一起搭配使用。在圖2-1中的video swin tranformer block下方有×2\times 2×2或是×6\times 6×6這樣的符號,表示有幾個blocks,這必定是個偶數(shù),比如×2\times 2×2就表示圖2-2這樣1組blocks,×6\times 6×6就表示圖2-2這樣3組blocks相連。
不難看出,有兩種blocks,每個block都是先過一個LN(LayerNorm),再過一個MSA(multi-head self-attention),再過一個LN,最后過一個MLP(multilayer perceptron),其中有兩處使用了殘差模塊。殘差塊主要是為了緩解梯度彌散。
兩種blocks的區(qū)別在于前者的MSA是window MSA,后者是shifted-window MSA。前者是為了window內(nèi)的信息交流(局部),后者是為了window間的信息交流(全局)。這個會在2.2中進(jìn)行詳述。
2.1.2 head
backbone的作用是提取視頻的特征,真正來做分類的還是接在backbone后面的head,這個部分就很簡單了,就是一層全連接,代碼中使用的是I3DHead。順便還帶了AdaptiveAvgPool3d,這是用來將輸入變成適合全連接的shape的。這部分就不說了,沒啥說的。
2.2 模塊詳述
2.2.1 Patch Partition
下圖2-3是一段視頻中的8幀,每幀都被分成了8×8=648 \times 8=648×8=64個網(wǎng)格,假設(shè)每個網(wǎng)格的像素為4×44 \times 44×4,那么當(dāng)patch size為(1,4,4)(1, 4, 4)(1,4,4)時,每個小網(wǎng)格就是一個patch;當(dāng)patch size為(2,4,4)(2,4,4)(2,4,4)時,沒相鄰兩幀的同一個位置的網(wǎng)格組成一個patch。這里和vision tranformer中的劃分方式相同,只不過多了時間的概念。
2.2.2 3D Patch Merging
3D Patch Merging這一塊直接看代碼會比較好理解,它和swin transformer中的2D patch merging一模一樣,3D Patch Merging雖然多了時間維度,但是并沒有對時間維度做merging的操作,也就是輸出的時間維度不變。
x0 = x[:, :, 0::2, 0::2, :] # B T H/2 W/2 C x1 = x[:, :, 1::2, 0::2, :] # B T H/2 W/2 C x2 = x[:, :, 0::2, 1::2, :] # B T H/2 W/2 C x3 = x[:, :, 1::2, 1::2, :] # B T H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # B T H/2 W/2 4*C看代碼再結(jié)合圖就更好理解了。圖中每個顏色都是一個patch。
2.2.3 W-MSA
MSA(multihead self attention)的原理這里就不說了,不懂的可以參見搞懂HMM,這里主要來說一說這個window。W-MSA(window based MSA)相比于MSA多了一個window的概念,相比于vision transformer引入window的目的是減小計(jì)算復(fù)雜度,使得復(fù)雜度和輸入圖片的尺寸成線性關(guān)系。這里不推導(dǎo)復(fù)雜度的計(jì)算,有興趣的可以看Swin Transformer論文精讀,這里有很詳細(xì)的推導(dǎo),3D和2D的復(fù)雜度計(jì)算方法是一致的。
窗口的劃分方式如圖2-5所示,每個窗口的大小由window size決定。圖2-5的window size為(4,4,4)(4,4,4)(4,4,4)就表示在時間,高度和寬度的window尺寸都是4個patch,劃分后的結(jié)果如圖2-5右半所示。之后的attention每個window單獨(dú)做,window之間不互相干擾。
2.2.4 SW-MSA
由于W-MSA的attention是局部的,作者就提出了SW-MSA(shifted window based MSA)。
圖2-6 SW-MSA示意圖SW-MSA如圖2-6所示,圖中shift size為(2,2,2)(2,2,2)(2,2,2),一般shift size都是window size的一半,也就是(P2,M2,M2)(\frac{P}{2}, \frac{M}{2}, \frac{M}{2})(2P?,2M?,2M?)。shift了之后,window會往前,往右,往下分別移動對應(yīng)的size,目的是讓patch可以和不同window的patch做特征的融合,這樣多過幾層之后,也就相當(dāng)于做了全局的特征融合。
不過這里有一個問題,shift了之后,window的數(shù)量從原來的2×2×2=82 \times 2 \times 2=82×2×2=8變成了3×3×3=273 \times 3 \times 3=273×3×3=27。這帶來的弊端就是計(jì)算時窗口不統(tǒng)一會比較麻煩。為了解決這個問題,作者引入了mask,并將窗口的位置進(jìn)行了移動,使得進(jìn)行shift和不進(jìn)行shift的MSA計(jì)算方式相同,只不過mask不同。
圖2-7 shift window示意圖我用PPT畫了一下shift的過程,畫圖能力有限,能看懂就好。我們的目的是把圖2-6中最右側(cè)的27個windows變成和圖2-6中間那樣的8個window。我給每個window都標(biāo)了序號,標(biāo)序號的方式是從前往后,從上往下,從左往右。shift window的方法就是把左上角的移到右下角,把前面的移到后面。這樣一來,比如[27,25,21,19,9,7,3,1][27, 25, 21, 19, 9, 7, 3, 1][27,25,21,19,9,7,3,1]就組成了1個window,[18,16,12,10][18, 16, 12, 10][18,16,12,10]就組成了1個window,依此類推,一共有8個windows。平移的方式可以和上述的不同,只要保證可以把27個windows變成和8個windows的計(jì)算方式一樣即可。
這樣在每個window做self-attention的時候,需要加一層mask,可以說是引入了inductive bias。因?yàn)樵诮M合而成的window內(nèi),各個小window我們不希望他們交換信息,因?yàn)檫@不是圖像原有的位置,比如17和11經(jīng)過shift之后,會在同一個window內(nèi)做attention,但是11是從上面移下來的,只是為了計(jì)算的統(tǒng)一,并不是物理意義上的同一個window。有了mask就不一樣了,mask的目的是告訴17號窗口內(nèi)的每一個patch,只和17號窗口內(nèi)的patches做attention,不和11號窗口內(nèi)的做attention,依此類推其他。
mask的生成方法可以參見源碼,這里不細(xì)講,主要思路是就像圖2-7這樣,給每個patch一個window的編號,編號相同的patch之間mask為0,否則為-100。
def compute_mask(D, H, W, window_size, shift_size, device):img_mask = torch.zeros((1, D, H, W, 1), device=device) # 1 Dp Hp Wp 1cnt = 0for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0],None):for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1],None):for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2],None):img_mask[:, d, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, window_size) # nW, ws[0]*ws[1]*ws[2], 1mask_windows = mask_windows.squeeze(-1) # nW, ws[0]*ws[1]*ws[2]attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))return attn_mask如果window的大小為圖2-6中的(P,M,M)(P, M, M)(P,M,M)的話,attention mask就是一個(P×M×M,P×M×M)(P \times M \times M,P \times M \times M)(P×M×M,P×M×M)的矩陣,這是一個對稱矩陣,第iii行第jjj列就表示window中的第iii個patch和第jjj個patch的window編號是否是相同的,相同則為0,不同則為-100。對角線上的元素必為0。
有人認(rèn)為淺層的網(wǎng)絡(luò)需要SW-MSA,深層的就不需要了,因?yàn)闇\層已經(jīng)講全局的信息都交流了,深層不需要進(jìn)一步交流了。這種說法的確有一定的道理,但也要看網(wǎng)絡(luò)的深度和shift的尺寸。
2.2.5 Relative Position Bias
在上述的所有內(nèi)容中,都沒有涉及到位置的概念,也就是模型并不知道每個patch在圖片中和其他patches的位置關(guān)系是怎么樣的,最有也就是知道某幾個patch是在同一個window內(nèi)的,但window內(nèi)的具體位置也是不知道的,因此就有了Relative Position Bias。它是加在attention的部分的,下式(2?1)(2-1)(2?1)中的BBB就是Relative Position Bias。
Attention(Q,K,V)=Softmax(QKT/d+B)V(2-1)Attention(Q,K,V) = Softmax(QK^T/\sqrtze8trgl8bvbq + B)V \tag{2-1} Attention(Q,K,V)=Softmax(QKT/d?+B)V(2-1)
很多swin tranformer的文章都會將這個BBB是如何得到的,但是卻沒有講為什么要這樣生成BBB。其實(shí)只要知道了設(shè)計(jì)這個BBB的目的,就可以不用管是如何生成的了,甚至自己設(shè)計(jì)一種生成的方法都行。
BBB是為了表示一個windows內(nèi),每個patch的相對位置,給每個相對位置一個特殊的embedding值。其實(shí)也正是因?yàn)檫@個BBB的存在,SW-MSA才必須要有mask,因?yàn)镾W-MSA內(nèi)的patches可能來自于多個windows,相對位置不能按照這個方法給,如果BBB可以表示全圖的相對位置,那就不用這個mask了。
這個B和mask的shape是一致的,也是(P×M×M,P×M×M)(P \times M \times M,P \times M \times M)(P×M×M,P×M×M)的矩陣,第iii行第jjj列就表示window中的第jjj個patch相對于第iii個patch的位置。
下圖2-8是我畫的一個示意圖,即使是一個(2,2,2)(2, 2, 2)(2,2,2)的window,我也感到工作量太大,矩陣沒填滿,畫了幾個示意了一下。如果window size為(P,M,M)(P, M, M)(P,M,M)的話,那么相對位置狀態(tài)就會有(2P?1)×(2M?1)×(2M?1)(2P-1) \times (2M-1) \times (2M-1)(2P?1)×(2M?1)×(2M?1)種狀態(tài),我把(2,2,2)(2, 2, 2)(2,2,2)的window的27種相對位置狀態(tài)全都在圖2-8上寫出來了。
圖2-8 Relative Position Bias示意圖有了狀態(tài)之后,就只需要在BBB這個矩陣中將相對位置的狀態(tài)對號入座即可。這就是很多其他博客寫的相對位置坐標(biāo)相減,然后加個偏置,再乘個系數(shù)的地方。理解了為什么要這么做,看那些操作也就不會覺得奇怪了。
但最終使用的不是狀態(tài),而是狀態(tài)對應(yīng)的embedding值,這就需要有一個table來根據(jù)狀態(tài)查找embedding,這個embedding是模型訓(xùn)練出來的。
3 模型效果
作者在三個數(shù)據(jù)集上進(jìn)行了測試,分別是kinetics-400,kinetics-600和something-something v2。每個數(shù)據(jù)集上都有著state-of-art的表現(xiàn)。
表3-1 kinetics-400模型對比指標(biāo) 表3-2 kinetics-600模型對比指標(biāo) 表3-3 something-something v2模型對比指標(biāo)參考資料
[1] Video Swin Transformer
[2] Swin-Transformer網(wǎng)絡(luò)結(jié)構(gòu)詳解
[3] Swin Transformer論文精讀
[4] Swin Transformer從零詳細(xì)解讀
[5] https://github.com/SwinTransformer/Video-Swin-Transformer
總結(jié)
以上是生活随笔為你收集整理的论文阅读 - Video Swin Transformer的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 指令系统——数据寻址(2)(详解)
- 下一篇: Promise学习笔记