别再无聊地吹捧了,一起来动手实现MAE玩玩吧!
?作者 | CW不要無聊的風(fēng)格
研究方向 | 目標(biāo)檢測、大規(guī)模預(yù)訓(xùn)練模型
前言
只要你不是與世隔絕的深度煉丹者,應(yīng)該都知道前陣子愷明大神的佳作 MAE(Masked Autoencoders Are Scalable Vision Learners),自雙 11 那天掛到 arXiv 后,江湖上就開始大肆吹捧:'yyds'、'best paper 預(yù)定' 什么的滿天飛.. 造成這一現(xiàn)象最主要原因還是大神本身的光環(huán)所致,另外就是大家看到 paper 中展示的 mask 掉圖像中這么多部分(75%~95%)后模型仍能重建回不錯(cuò)的效果,難免無腦地拍手叫好(沒有冒犯的意思,別打我,CW 當(dāng)時(shí)也沒忍住拍手叫好了..)。
但是,作為 coder,只動嘴不覺得很無聊么!?CW 可不要無聊的風(fēng)格,既然 MAE 看起來這么牛逼,那就干脆動手碼一碼,實(shí)現(xiàn)下看它效果如何嘛。盡管沒有代碼還沒有開源,但方法本身足夠簡單,paper 也大致描述了下實(shí)現(xiàn)的方法。
于是,CW 出于愛玩的心態(tài)并結(jié)合一貫以來不無聊的風(fēng)格,在周末的一個(gè)下午,去了我最愛的 cafe(上了新豆子,好棒!),邊碼邊喝咖啡,自己實(shí)現(xiàn)了 MAE,還挺有意思的。
本文會先講述 MAE 的原理與方法,然后針對 paper 中的實(shí)驗(yàn)現(xiàn)象談?wù)勛约旱睦斫?#xff0c;最后再分享與解析自己的源碼實(shí)現(xiàn)。
概述
MAE 的做法可以用一句話概述:以一定比例隨機(jī) mask 掉圖片中的一些圖像塊(patch)然后重建這些部分的像素值。
主要特點(diǎn)有兩個(gè):
1. 非對稱的編、解碼器設(shè)計(jì);
2. 使用較高(如75%)的掩碼率(mask比例)。
第 1 點(diǎn)所述的“非對稱”主要體現(xiàn)在輸入形式與網(wǎng)絡(luò)結(jié)構(gòu)上:編碼器(Encoder)僅對可見(unmasked)的圖像塊進(jìn)行編碼,而解碼器(Decoder)的輸入則是所有的圖像塊;同時(shí),Decoder 可以是比較輕量的(比如 Encoder 通常是多層堆疊的 Transformer,而 Decoder 僅需較少層甚至 1 層就 ok)。這也表明 Encoder 與 Decoder 之間是解耦的。
第 2 點(diǎn)是該工作的一個(gè)重要發(fā)現(xiàn):不同于 NLP,在 CV 中可能要配合較高的 mask 比例才能作為“有效”的自監(jiān)督代理任務(wù)。“有效”指的是任務(wù)本身足夠困難,這樣模型才能學(xué)到有效的潛在特征表示。
由于 Encoder 僅處理 unmasked 的 patch(占所有輸入的少數(shù)),因此,盡管其本身網(wǎng)絡(luò)結(jié)構(gòu)比較重載,但依然能夠高效訓(xùn)練,特別是對于大模型,能夠加速 3 倍以上,同時(shí)配合較高的掩碼率,還能夠漲點(diǎn)。
我們知道,MAE 的方法屬于掩碼自編碼(Masked Autoencoding)范疇,那么,為何要用這種玩法呢?
好奇心:Why Masked Autoencoding?
得益于硬件發(fā)展與算力的支持,現(xiàn)在的模型越玩越大,大模型由于參數(shù)量眾多,因此也很容易過擬合一般規(guī)模的數(shù)據(jù)集。于是,再這么玩下去就需要更大量的數(shù)據(jù),而這么大量的標(biāo)注數(shù)據(jù)人工成本是很高的,作者也不禁 diss 一波:很多人(你們說是誰呢?會不會是姓 G 的那位呢?)吶,還用他們私有的數(shù)據(jù)集關(guān)起門來偷偷玩,不肯和大家分享:
Aided by the rapid gains in hardware, models today can easily overfit one million images and begin to demand hundreds of millions of—often?publicly inaccessible—labeled images.
所以說,這么玩下去成本太高了,玩不起呀,于是就想方設(shè)法地開辟出了新的玩法:自監(jiān)督預(yù)訓(xùn)練。其中,較為常見的一種模式就是 masked autoencoding,這種這玩法在 NLP 尤為火熱,大名鼎鼎的 BERT 在預(yù)訓(xùn)練中就是這么玩的:以一定比例 mask 掉輸入文本中的一些部分,讓模型去預(yù)測這批被 mask 掉的內(nèi)容。這樣,利用數(shù)據(jù)本身就可以作為監(jiān)督(模型要預(yù)測的目標(biāo)來源于數(shù)據(jù)本身,并非人工構(gòu)造),無需復(fù)雜的人工標(biāo)注。同時(shí),使用大量的數(shù)據(jù)讓擁有大規(guī)模參數(shù)量的模型能夠?qū)W到通用的知識,從而擁有良好的泛化能力。
以上談到的是預(yù)訓(xùn)練階段,當(dāng)模型實(shí)際要應(yīng)用于不同的下游任務(wù)時(shí),還要使用少量的標(biāo)注數(shù)據(jù)進(jìn)行微調(diào)(fine-tune),這樣才能夠真正應(yīng)對目標(biāo)任務(wù)。
按照以前的玩法,在面對不同的任務(wù)時(shí),我們都需要重新設(shè)計(jì)模型結(jié)構(gòu),然后用特定任務(wù)的全量標(biāo)注數(shù)據(jù)去進(jìn)行訓(xùn)練。而現(xiàn)在不用了,只要設(shè)計(jì)了合理的預(yù)訓(xùn)練任務(wù),讓大規(guī)模模型在大量的上游數(shù)據(jù)中完成了預(yù)訓(xùn)練,它就能學(xué)到“通用知識”,猶如“通才”;之后,在面對不同的任務(wù)時(shí),我們都可以利用這個(gè) pre-trained 大模型,在少量的下游數(shù)據(jù)中進(jìn)行二次學(xué)習(xí),讓其成為“專才”。
由于大模型參數(shù)量眾多,因此能夠很快擬合,在面對不同任務(wù)時(shí)都能夠高效學(xué)習(xí)(相對地,正是由于模型參數(shù)太多了,因此很容易過擬合到下游訓(xùn)練集,反而喪失了泛化能力,這也是 fine-tune 玩法的一大毛病)。
靈魂拷問:Why Masked Autoencoding In CV Lags Behind NLP?
OK,我們知道了 mask 這種玩法在 NLP 很流行,那為什么在 CV 中卻比較冷門呢?作者也向大家發(fā)起了靈魂拷問:
progress of autoencoding methods in vision lags behind NLP.?
We ask:?what makes masked autoencoding different between vision and language?
好吧,看沒人回答,作者只能自我深刻分析(這樣才能把故事講完),最終提煉出以下三點(diǎn):
i). 架構(gòu)(architecture)差異
CV 和 NLP 的網(wǎng)絡(luò)架構(gòu)不一致,前者在過去一直被 CNN 統(tǒng)治,它基于方正的局部窗口來操作,不方便集成像 mask token 以及 position embedding 這類帶有指示性的可學(xué)習(xí)因子。不過,這個(gè) gap 現(xiàn)在看來應(yīng)該可以解決了,因?yàn)?ViT(Vision Transformer)已經(jīng)在 CV 界大肆虐殺,風(fēng)頭很猛..
ii). 信息密度(information density)不同
圖像和語言的信息密度是不一樣的。語言是人類創(chuàng)造的,本身就是高度語義和信息密集的,于是將句子中的少量詞語抹去再讓模型去預(yù)測這些被抹去的詞本身就已經(jīng)是比較困難的任務(wù)了;而對于圖像則相反,它在空間上是高度冗余的,對于圖片中的某個(gè)部分,模型很容易由其相鄰的圖像塊推斷出來(你想想看插值的道理),不需要大量的高級語義信息。
因此,在 CV 中,如果要使用 mask 這種玩法,就應(yīng)該要 mask 掉圖片中的較多的部分,這樣才能使任務(wù)本身具有足夠的挑戰(zhàn)性,從而使模型學(xué)到良好的潛在特征表示。
iii). 解碼的目標(biāo)不一致
CV 和 NLP 在解碼器的設(shè)計(jì)上應(yīng)該有不一樣的考慮:NLP 解碼輸出的是對應(yīng)被 mask 掉的詞語,本身包含了豐富的語義信息;而 CV 要重建的是被 mask 掉的圖像塊(像素值),是低語義的。
因此,NLP 的解碼器可以很簡單,比如 BERT,嚴(yán)格來說它并沒有解碼器,最后用 MLP 也可以搞定。因?yàn)閬碜跃幋a器的特征也是高度語義的,與需要解碼的目標(biāo)之間的 gap 較小;而 CV 的解碼器設(shè)計(jì)則需要“謹(jǐn)慎”考慮了,因?yàn)樗獙?strong>來自編碼器的高級語義特征解碼至低級語義層級。
基于以上三點(diǎn)的自我分析(作者很入戲,估計(jì)還喝了口咖啡),靈感一來,MAE 就被 present 出來了:
Driven by this analysis, we present a simple, effective, and scalable form of a masked autoencoder (MAE) for visual representation learning.
喲!你瞧,simple, effective, and scalable,作者自己都很滿意~
什么!?你說他自吹自擂你不服?好,愷明大神立馬放一波效果圖讓你開開眼界:
以上每 3 列為一組,每組中的左列是 mask 掉原圖 80% 部分的效果圖,中列是模型重建的效果,右列是原圖。
什么?還要用數(shù)字說話?好,自個(gè)兒看:
With a vanilla ViT-Huge model, we achieve?87.8%?accuracy when finetuned on ImageNet-1K. This?outperforms all previous results that use only ImageNet-1K data.
具體方法
是時(shí)候來談?wù)?MAE 的具體方法了。雖然前面鋪墊了那么多,但是 CW 認(rèn)為這是有必要的。教員也告訴我們,看問題要有廣度、深度、精度:先縱觀歷史有全局認(rèn)識,再結(jié)合當(dāng)前情況深入分析,從而抓住問題的重點(diǎn),最終才能追溯到本質(zhì)。
結(jié)合前面的敘述,我們知道 MAE 方法的特點(diǎn)主要有:高掩碼率的隨機(jī) mask 策略、非對稱的編、解碼器設(shè)計(jì)以及重建的目標(biāo)是像素值。下面,就請各位朋友和 CW 一起來具體看看其中的每個(gè)部分。
5.1 Mask 策略
首先,沿襲 ViT 的做法,將圖像分成一塊塊(ViT 中是 16x16 大小)不重疊的 patch,然后使用服從均勻分布(uniform distribution)的采樣策略對這些 patches 隨機(jī)采樣一部分,同時(shí) mask 掉余下的另一部分。被 mask 掉的 patches 占所有 patches 的大部分(實(shí)驗(yàn)效果發(fā)現(xiàn)最好的比例是 75%),它們不會輸入到 Encoder。
OK,策略很簡單,那么這樣做有什么好處呢?
首先,patch 在圖像中是服從均勻分布來采樣的,這樣能夠避免潛在的“中心歸納偏好”(也就是避免 patch 的位置大多都分布在靠近圖像中心的區(qū)域);其次,采用高掩碼比例(mask 掉圖中大部分 patches)能夠防止模型輕易地根據(jù)鄰近的可見 patches 推斷(原文是 extrapolation,外推,這詞有點(diǎn)高級..)出這些掩碼塊;最后,這種策略還造就了稀疏的編碼器輸入,因?yàn)?Encoder 只處理可見的 patches,于是能夠以更低的代價(jià)訓(xùn)練較大規(guī)模的 Encoder,因?yàn)?strong>計(jì)算量和內(nèi)存占用都減少了。
別看這 mask 策略好像挺簡單的,但卻是至關(guān)重要的一個(gè)部分,因?yàn)槠錄Q定了預(yù)訓(xùn)練代理任務(wù)是否具有足夠的挑戰(zhàn)性,從而影響著 Encoder 學(xué)到的潛在特征表示 以及 Decoder 重建效果的質(zhì)量。
下圖是作者在 paper 中展示的基于不同 mask 策略進(jìn)行訓(xùn)練后模型的表現(xiàn)。我們?nèi)庋劭梢?#xff0c;以上提到的隨機(jī)(服從均勻分布)采樣策略下模型的表現(xiàn)最好。注意,圖中的 'block' 策略由于 mask 掉的是大塊的 patch,因此 mask 比例設(shè)置了 50%,以達(dá)到和其它策略 mask 掉的部分占原圖比例較為接近的效果。
還有呀,mask 比例也是很重要的,CW 在前文也提到過,在 CV 中,只有 mask 掉圖中較多的部分才能形成具有挑戰(zhàn)性的任務(wù)。作者實(shí)驗(yàn)發(fā)現(xiàn),無論是 fine-tune 還是 linear-probe 下,75% 左右的 mask 比例都是比較好的一個(gè)選擇。
5.2 Encoder
記住最重要的一點(diǎn),Encoder 僅處理可見(unmasked)的 patches。Encoder 本身可以是 ViT 或 ResNet(其它 backbone 也 ok,就等你去實(shí)現(xiàn)了,大神給了你機(jī)會),至于如何將圖像劃分成 patch 嘛,使用 ViT 時(shí)的套路是這樣的:
先將圖像從(B,C,H,W)reshape 成(B,N,PxPxC),其中 N 和 P 分別為 patch 數(shù)量 和 patch 大小(),也就是將 3 通道的圖像轉(zhuǎn)換成 N 個(gè) 維度大小為 PxPxC 的向量;然后,通過線性映射(linear projection,可以是全連接層)將其嵌入(embed)到指定的維度空間大小,記為 'dim'(從 PxPxC project 到 dim),轉(zhuǎn)換成為 token(B,N,dim);最后再加上位置嵌入(position embedding),從而為各個(gè) patch 添加位置信息。位置嵌入是所有圖像共享的、可學(xué)習(xí)的,shape 與 每張圖的 token 相對應(yīng),即:(N,dim)。
由于 unmasked 的 patches 所有 patches 的少數(shù),因此可以訓(xùn)練很大的 Encoder,因?yàn)橛?jì)算和空間要求都減少了。
5.3 Decoder
Decoder 嘛.. 就別想著偷懶了,它不僅需要處理經(jīng)過 Encoder 編碼的 unmasked 的 tokens,還需要處理 mask tokens。但請注意,mask token 并非由之前 mask 掉的 patch 經(jīng)過 embedding 轉(zhuǎn)換而來,而是可學(xué)習(xí)的、所有 masked patch 都共享的 1 個(gè)向量,對,僅僅就是 1 個(gè)!
那么你會問:這樣如何區(qū)分各個(gè) maked patch 所對應(yīng)的 token 呢?
別忘了,我們還有 position embedding 嘛!如同在 Encoder 中的套路一樣,這里對于 mask token 也需要加入位置信息。position emebdding 是每個(gè) masked patch 對應(yīng) 1 個(gè),shape 是(N',dim),其中 N' 是 masked patch 的數(shù)量。但 mask token 只有 1 個(gè)怎么辦是不是?簡單粗暴——“復(fù)制”多份即可,使得每個(gè) masked patch 都對應(yīng) 1 個(gè) mask token,這樣就可以和 position embedding 進(jìn)行相加了。
另外,Decoder 僅僅是在預(yù)訓(xùn)練任務(wù)為了重建圖像而存在,而我們的下游任務(wù)形式多樣,因此實(shí)際應(yīng)用時(shí)很可能沒 Decoder 什么事了(和它 say byebye 咯~)。所以,Decoder 的設(shè)計(jì)和 Encoder 是解耦的,Decoder 可以設(shè)計(jì)得簡單、輕量一些(比 Encoder 更窄、更淺。窄:對應(yīng)通道數(shù);淺:對應(yīng)深度),畢竟真正能學(xué)習(xí)到潛在特征表示的是 Encoder。
這樣,盡管 Decoder 要處理的 token 數(shù)很多(全量token,而 Encoder 僅處理 unmasked 的部分),但其本身輕量,所以還是能夠高效計(jì)算。再結(jié)合 Encoder 雖然本身結(jié)構(gòu)重載(相對 Decoder 來說),但其處理的 token 較少,這樣,整體架構(gòu)就十分 efficient 了,漂亮~!
5.4 任務(wù)目標(biāo):重建像素值
MAE 預(yù)訓(xùn)練任務(wù)的目標(biāo)是重建像素值,并且僅僅是 masked patch 的像素值,也就是僅對 mask 掉的部分計(jì)算 loss,而 loss 就是很大眾的 MSE。為何僅計(jì)算 mask 部分的 loss?實(shí)驗(yàn)結(jié)果發(fā)現(xiàn)這樣做模型的性能會更好,而如果對所有 patches 都計(jì)算 loss 的話會掉點(diǎn):
Computing the loss only on masked patches differs from traditional denoising autoencoders that compute the loss on all pixels. This choice is purely?result-driven:?
computing the loss on all pixels leads to a slight decrease in accuracy (e.g., ~0.5%).
那么模型是如何去預(yù)測 masked patch 的像素值并計(jì)算 loss 的呢?具體來說,就是:
在 Decoder 解碼后的所有 token 中取出 mask tokens(在最開始 mask 掉 patch 的時(shí)候可以先記錄下這些 masked 部分的索引),將這些 mask tokens 送入全連接層,將輸出通道映射到 1 個(gè) patch 的像素?cái)?shù)量(PxPxC),也就是輸出的 shape 是:(B,N',PxPxC),其中的每個(gè)值就代表預(yù)測的像素值。最后,以之前 mask 掉的 patch 的像素值作為 target,與預(yù)測結(jié)果計(jì)算 MSE loss。
另外,作者提到使用歸一化的像素值作為 target 效果更好,能夠提升學(xué)到的表征的質(zhì)量。這里的歸一化做法是:計(jì)算每個(gè) patch 像素值的均值與標(biāo)準(zhǔn)差,然后用均值與標(biāo)準(zhǔn)差去歸一化對應(yīng)的 patch 像素。
5.5 Pipeline
OK,解析完 MAE 的各部分結(jié)構(gòu),現(xiàn)在 CW 就將它們串起來:
1. 將圖像劃分成 patches:(B,C,H,W)->(B,N,PxPxC);
2. 對各個(gè) patch 進(jìn)行 embedding(實(shí)質(zhì)是通過全連接層),生成 token,并加入位置信息(position embeddings):(B,N,PxPxC)->(B,N,dim);
3. 根據(jù)預(yù)設(shè)的掩碼比例(paper 中提倡的是 75%),使用服從均勻分布的隨機(jī)采樣策略采樣一部分 token 送給 Encoder,另一部分“扔掉”(mask 掉);
4. 將 Encoder 編碼后的 token 與 加入位置信息后的 mask token 按照原先在 patch 形態(tài)時(shí)對應(yīng)的次序拼在一起,然后喂給 Decoder 玩(如果 Encoder 編碼后的 token 的維度與 Decoder 要求的輸入維度不一致,則需要先經(jīng)過 linear projection 將維度映射到符合 Decoder 的要求);
4. Decoder 解碼后取出 mask tokens 對應(yīng)的部分送入到全連接層,對 masked patches 的像素值進(jìn)行預(yù)測,最后將預(yù)測結(jié)果與 masked patches 進(jìn)行比較,計(jì)算 MSE loss。
實(shí)驗(yàn)理解
這部分給大家 show 下 paper 中的部分實(shí)驗(yàn)結(jié)果,并針對其中一些現(xiàn)象談?wù)勛约旱睦斫狻?/p>
6.1 Mask 比例
前文也多次談到,mask 比例較高才能形成具有挑戰(zhàn)性的預(yù)訓(xùn)練任務(wù),模型才更有機(jī)會學(xué)到更好的潛在特征表示。由上圖中的實(shí)驗(yàn)結(jié)果也可以看到,無論是在 fine-tune 還是 linear probe 的玩法中,mask 比例逐漸升高(但不過份)時(shí),模型性能都會更好。
但是,fine-tune 和 linear probe 的結(jié)果還是有所區(qū)別的:linear probe 幾乎是線性增漲的趨勢,而 fine-tune 則是 mask 比例在 30%~40% 之間激增,而后就傾向于飽和了。
So,為啥會醬捏?
CW 覺得,linear probe 之所以沒有那么快飽和,和其本身的玩法相關(guān)——僅調(diào)整模型最后的幾層分類頭(fix 住其它部分,如 Encoder)。因此,mask 比例越高,在預(yù)訓(xùn)練時(shí)得到的 Encoder 就越強(qiáng),但這部分在下游任務(wù)中是不能夠再被訓(xùn)練的了,所以其性能就隨著 mask 比例的增加呈線性增漲的趨勢。
相對地,fine-tune 時(shí),還能夠繼續(xù)訓(xùn)練 Encoder 的參數(shù)去適配下游任務(wù),因此在 mask 比例超過一定程度后,對于下游任務(wù)的性能提升就不那么明顯了。
6.2 Mask 采樣策略
作者通過實(shí)驗(yàn)比較,最終選擇了服從均勻分布的隨機(jī)采樣(作者稱其為 'random')策略,以上是詳細(xì)的實(shí)驗(yàn)結(jié)果。
可以觀察出,block-wise 策略由于掩蓋掉的圖像塊區(qū)域太大了,因此在高于 50% 的 mask 比例下效果就不好(因?yàn)槟惚旧砭驼诘脧V,現(xiàn)在還要遮得多,太難了吧..)。
而對于 grid 策略,作者說,這種方式在訓(xùn)練時(shí)能夠?qū)?shù)據(jù)擬合得很好,但實(shí)際學(xué)到的特征表示泛化性其實(shí)是比較弱的。
由此可以說明,代理任務(wù)設(shè)計(jì)得太困難(對應(yīng) block-wise)或太簡單(對應(yīng) grid)都不行,要適當(dāng)(對應(yīng) random)才好,此乃中庸之道~
6.3 Decoder 的設(shè)計(jì)
作者還探究了 Decoder 的設(shè)計(jì)。上圖展示了不同的 Decoder 深度(Transformer 層數(shù))和寬度(通道數(shù))對于 fine-tune 和 linear probe 在 ImageNet-1K 下游任務(wù)中的表現(xiàn)。
可以發(fā)現(xiàn),Decoder 的深度和寬度對于 linear probe 有較為明顯的影響,但對于 fine-tune 的影響卻不那么突出。
So,為啥會醬捏(again)?
想一想,Decoder 更深和更寬時(shí),會發(fā)生什么?
(自問自答):當(dāng) Decoder 更深/寬時(shí),它本身會擁有更強(qiáng)的重建能力,這樣就使得在預(yù)訓(xùn)練時(shí) Encoder 能更專注于提取抽象語義層級的特征,專心做事了,產(chǎn)生的質(zhì)量也就更好了。也就是說,Encoder 在提取良好特征方面更專業(yè)了。
OK,了解了以上這點(diǎn)沒錯(cuò),但這種效應(yīng)是同樣作用于 linear probe 和 fine-tune 的,那么為何會造成不同的影響程度呢?
進(jìn)一步探究,其實(shí)還是與它們各自的玩法相關(guān):
linear probe 是完全繼承預(yù)訓(xùn)練 Encoder 的玩法(因其僅調(diào)最后幾層分類頭),而 fine-tune 在下游任務(wù)中仍能夠繼續(xù)調(diào)整 Encoder 的參數(shù)。于是,預(yù)訓(xùn)練時(shí)得到的 Encoder 牛不牛逼,對 linear probe 產(chǎn)生的影響會更大一些。
以上的話太白了,有點(diǎn) low,再裝裝逼:
究其本質(zhì),其實(shí)是預(yù)訓(xùn)練任務(wù)(圖像重建)與下游任務(wù)(圖像識別)之間存在著 gap!
fine-tune 時(shí)由于能夠調(diào)整 Encoder 去適配圖像識別任務(wù),因此預(yù)訓(xùn)練對其影響程度就相對沒那么大了。
6.4 Mask token 為何被 Encoder “拋棄”?
我們知道,在 MAE 中,Encoder 僅玩 unmasked 的 tokens。那么,如果它也玩 mask tokens 會怎樣呢?
你們別說:肯定會掉點(diǎn)嘛,不然作者干嘛不玩?
給點(diǎn)面子..
是的,如上圖中的實(shí)驗(yàn)結(jié)果顯示,會掉點(diǎn)(汗~)。原因也很直白:因?yàn)樵?strong>下游任務(wù)中并不存在這些 mask tokens,上、下游任務(wù)之間存在 gap(這點(diǎn)在當(dāng)年 BERT 出道時(shí)已經(jīng)暴露了出來)。如果 Encoder 也對 mask tokens 進(jìn)行編碼,會進(jìn)一步將這種 gap 的影響“擴(kuò)散”至下游任務(wù)中造成影響。
6.5 各種重建目標(biāo)的比較
MAE 的重建目標(biāo)是 mask patches 的像素值。同時(shí),作者在 paper 中還提到,如果預(yù)測的是歸一化(具體做法 CW 在上文中有描述)的像素值,那么效果會更好。另外,作者還和 BEiT 那種預(yù)測 token 的方式 以及 PCA 的方式(對 patch 空間實(shí)施 PCA 并預(yù)測最大的因子)進(jìn)行了比較:
可以發(fā)現(xiàn),預(yù)測歸一化像素值的方式最強(qiáng),BEiT 那種 token 的方式也差不多,那么,這種現(xiàn)象說明了什么呢?
回顧下前文 CW 提到的,這里歸一化像素值的做法是分別針對每個(gè) patch 使用它們獨(dú)立統(tǒng)計(jì)出來的均值與方差去歸一化的,這就會將各個(gè) patch 歸一化到不同的表示空間,從而分成不同的“簇”,于是各個(gè) patch 之間的差異性就更強(qiáng),形成了高頻信息,相當(dāng)于將各個(gè) patch 構(gòu)造成了邊緣與紋理,從整體圖像看來,對比度更高。從而使得模型更有針對性地學(xué)習(xí)各個(gè) patch 的特征模式。同時(shí),數(shù)值上由于做了歸一化,因此又不會使得模型在這方面有所偏倚。
至于 token 的方式是照搬 NLP 的玩法,是高度離散化和語義化的,一個(gè)字的差異也可能導(dǎo)致詞語之間的含義發(fā)生重大變化,本身就是高頻東西。
因此,究其本質(zhì):高頻信息才是王道!
6.6 數(shù)據(jù)增強(qiáng)
大家都知道,玩 CV 嘛肯定離不開數(shù)據(jù)增強(qiáng),于是作者探究了這老套路對于 MAE 方法的影響:
由上圖中的實(shí)驗(yàn)結(jié)果可知,這老套路果然還是有好處的。但是可以看到,不做隨機(jī)縮放(fixed size)和隨機(jī)縮放(rand size)的效果其實(shí)差不多,而采用色彩擾動(color jit)卻反而比簡單的 crop 還菜,有意思~
稍微想一下,這應(yīng)該是 MAE 本身 masking 的做法已經(jīng)是一種數(shù)據(jù)增強(qiáng)手段了,因此不需要“過份”的額外數(shù)據(jù)增強(qiáng)就能取得較好的效果(比如 color jit,本身就 mask 掉圖像的一些部分了,還來擾亂原本的像素值,模型當(dāng)然覺得不好搞啊..)。
6.7 干倒 linear probe
linear probe 一直是很流行的玩法,但通過上面的實(shí)驗(yàn)結(jié)果我們可以發(fā)現(xiàn),它與 fine-tune 之間總是存在著“不協(xié)同”的結(jié)果,比如前面說到的 Decoder 的深度和寬度對 linear probe 的影響挺大但對于 fine-tune 來說卻并不那么事關(guān)緊要。
于是,作者不禁懷疑起 linear probe 這種玩法的道理?!皺?quán)衡”了 linear probe 和 fine-tune 之間的做法,作者設(shè)計(jì)出一種 'partial fine-tuning' 的玩法:僅調(diào)整 Encoder 的最后幾層但 fix 住其它部分。如上圖所示,調(diào)整 0 個(gè) block 相當(dāng)于是 linear probe,而調(diào)整所有 24 個(gè) blocks 就是 fine-tuning 的玩法。
可以看到,對于 MAE,僅調(diào)整 1 個(gè) block 就可以將 acc 從73.5%(linear probe)漲到81%,并且對于 MOCO v3 也一樣可以漲點(diǎn)。
另外,MAE 在 partial fine-tuning 的方式下優(yōu)于 MOCO v3,這也說明 MAE 學(xué)到的特征非線性更強(qiáng),于是當(dāng)可以調(diào)整非線性頭部時(shí)效果就更好。
在這里,作者認(rèn)為 linear probe 有必要去“面壁思過”一下,因?yàn)樗@種玩法沒有去捕捉一些強(qiáng)大但非線性的特征,而這卻恰恰是深度學(xué)習(xí)所更應(yīng)該重視和擁有的。
于是,這些現(xiàn)象都向我們表明:linear probe 并非是唯一的、正確地評估模型學(xué)到的表征質(zhì)量的方式。并且,作者后續(xù)還進(jìn)行了 detection 與 segmentation 相關(guān)的實(shí)驗(yàn),從而在 linear probe 的玩法中學(xué)到的特征也并非是和遷移學(xué)習(xí)性能強(qiáng)相關(guān)的。
(其實(shí)就是想偷偷告訴你們:別被 linear probe 帶偏了哦~)
開局:源碼實(shí)現(xiàn)
終于到好玩的部分了,以上那些都是吹水,coder 還得動手寫代碼才好玩!
官方?jīng)]有開源,但是 MAE 本身的方法足夠簡單,因此 CW 就自己腦洞了下,試著按照 paper 描述的(當(dāng)然,還結(jié)合了自己的風(fēng)格)去實(shí)現(xiàn)。
(ps: 以下代碼基于Pytorch 框架,僅供娛樂使用)
先來看看 MAE 模型的初始化:
class?MAE(nn.Module):def?__init__(self,?encoder,?decoder_dim,?mask_ratio=0.75,?decoder_depth=1,?num_decoder_heads=8,?decoder_dim_per_head=64):super().__init__()assert?0.?<?mask_ratio?<?1.,?f'mask?ratio?must?be?kept?between?0?and?1,?got:?{mask_ratio}'#?Encoder(這里?CW?用?ViT?實(shí)現(xiàn))self.encoder?=?encoderself.patch_h,?self.patch_w?=?encoder.patch_h,?encoder.patch_w#?由于原生的 ViT 有 cls_token,因此其 position embedding 的倒數(shù)第2個(gè)維度是:#?實(shí)際劃分的?patch?數(shù)量加上?1個(gè)?cls_tokennum_patches_plus_cls_token,?encoder_dim?=?encoder.pos_embed.shape[-2:]#?Input?channels?of?encoder?patch?embedding:?patch?size**2?x?3#?這個(gè)用作預(yù)測頭部的輸出通道,從而能夠?qū)?patch?中的所有像素值進(jìn)行預(yù)測num_pixels_per_patch?=?encoder.patch_embed.weight.size(1)# Encoder-Decoder:Encoder 輸出的維度可能和 Decoder 要求的輸入維度不一致,因此需要轉(zhuǎn)換self.enc_to_dec?=?nn.Linear(encoder_dim,?decoder_dim)?if?encoder_dim?!=?decoder_dim?else?nn.Identity()#?Mask?token#?社會提倡這個(gè)比例最好是?75%self.mask_ratio?=?mask_ratio# mask token 的實(shí)質(zhì):1個(gè)可學(xué)習(xí)的共享向量self.mask_embed?=?nn.Parameter(torch.randn(decoder_dim))# Decoder:實(shí)質(zhì)就是多層堆疊的 Transformerself.decoder?=?Transformer(decoder_dim,decoder_dim?*?4,depth=decoder_depth,?num_heads=num_decoder_heads,dim_per_head=decoder_dim_per_head,?)#?在?Decoder?中用作對?mask?tokens?的?position?embedding#?Filter?out?cls_token?注意第1個(gè)維度去掉?cls_tokenself.decoder_pos_embed?=?nn.Embedding(num_patches_plus_cls_token?-?1,?decoder_dim)#?Prediction?head?輸出的維度數(shù)等于1個(gè)?patch?的像素值數(shù)量self.head?=?nn.Linear(decoder_dim,?num_pixels_per_patch)接下來,CW 會分各部分進(jìn)行解析,下面一起來看看咯(看完你們自己動手寫寫,會更好玩)~
ps:以上 Encoder 部分的 ViT 和 Decoder 部分的 Transformer 的實(shí)現(xiàn)沒有什么特別的,和開源的主流實(shí)現(xiàn)一致,比較無聊,因此 CW 在下文中不會對這部分進(jìn)行解析(其實(shí)想偷個(gè)懶,哈哈哈~!)。發(fā)現(xiàn)不少朋友在評論區(qū)說還是忍不住想看這部分的代碼實(shí)現(xiàn),好吧..我放在文末附錄(代碼 200 行不到,我就費(fèi)事讓大家跳到 github 了)咯~
7.1 Patch Partition
如前文所述,我們首先需要將圖像劃分成 patch,劃分方式實(shí)質(zhì)就是維度的變換:
num_patches?=?(h?//?self.patch_h)?*?(w?//?self.patch_w) #?(b,?c=3,?h,?w)->(b,?n_patches,?patch_size**2?*?c) patches?=?x.view(b,?c,h?//?self.patch_h,?self.patch_h,?w?//?self.patch_w,?self.patch_w ).permute(0,?2,?4,?3,?5,?1).reshape(b,?num_patches,?-1)7.2 Masking
接下來,就是根據(jù)預(yù)設(shè)的 mask 比例采用服從均勻分布的策略隨機(jī)采樣一批 patches 喂給 Encoder,剩下的就 mask 掉:
#?根據(jù)?mask?比例計(jì)算需要?mask?掉的?patch?數(shù)量 #?num_patches?=?(h?//?self.patch_h)?*?(w?//?self.patch_w) num_masked?=?int(self.mask_ratio?*?num_patches)#?Shuffle:生成對應(yīng)?patch?的隨機(jī)索引 #?torch.rand()?服從均勻分布(normal?distribution) #?torch.rand()?只是生成隨機(jī)數(shù),argsort()?是為了獲得成索引 #?(b,?n_patches) shuffle_indices?=?torch.rand(b,?num_patches,?device=device).argsort() #?mask?和?unmasked?patches?對應(yīng)的索引 mask_ind,?unmask_ind?=?shuffle_indices[:,?:num_masked],?shuffle_indices[:,?num_masked:]#?對應(yīng) batch 維度的索引:(b,1) batch_ind?=?torch.arange(b,?device=device).unsqueeze(-1) #?利用先前生成的索引對?patches?進(jìn)行采樣,分為?mask?和?unmasked?兩組 mask_patches,?unmask_patches?=?patches[batch_ind,?mask_ind],?patches[batch_ind,?unmask_ind]7.3 Encode
OK,這時(shí)候我們就可以在 Encoder 中對 unmasked 的 patches 進(jìn)行編碼了。
當(dāng)然,我們得先對 unmasked patches 進(jìn)行 emebdding 轉(zhuǎn)換成 tokens,并且加上 position embeddings,從而為它們添加位置信息,然后才能是真正的編碼過程。至于編碼過程,實(shí)質(zhì)上就是扔給 Transformer 玩(query 和 key 玩一玩,玩出個(gè) attention 后再和 value 一起玩~):
#?將?patches?通過?emebdding?轉(zhuǎn)換成?tokens unmask_tokens?=?self.encoder.patch_embed(unmask_patches) #?為?tokens?加入?position?embeddings? #?注意這里索引加1是因?yàn)樗饕?對應(yīng)?ViT?的?cls_token unmask_tokens?+=?self.encoder.pos_embed.repeat(b,?1,?1)[batch_ind,?unmask_ind?+?1] #?真正的編碼過程 encoded_tokens?=?self.encoder.transformer(unmask_tokens)7.4 Decode
Encoder 玩完后輸出編碼后的 tokens,首先將編碼后的 tokens 和 添加了位置信息后的 mask tokens 按原先對應(yīng) patches 的次序拼起來,然后喂給 Decoder 解碼。需要注意的是,編碼后的 tokens 維度若與 Decoder 要求的輸入維度不一致,需要使用 linear projection 進(jìn)行轉(zhuǎn)換。
#?對編碼后的?tokens?維度進(jìn)行轉(zhuǎn)換,從而符合?Decoder?要求的輸入維度 enc_to_dec_tokens?=?self.enc_to_dec(encoded_tokens)#?由于?mask?token?實(shí)質(zhì)上只有1個(gè),因此要對其進(jìn)行擴(kuò)展,從而和?masked?patches?一一對應(yīng) #?(decoder_dim)->(b,?n_masked,?decoder_dim) mask_tokens?=?self.mask_embed[None,?None,?:].repeat(b,?num_masked,?1) #?為?mask?tokens?加入位置信息 mask_tokens?+=?self.decoder_pos_embed(mask_ind)#?將?mask?tokens?與?編碼后的?tokens?拼接起來 #?(b,?n_patches,?decoder_dim) concat_tokens?=?torch.cat([mask_tokens,?enc_to_dec_tokens],?dim=1) # Un-shuffle:恢復(fù)原先 patches 的次序 dec_input_tokens?=?torch.empty_like(concat_tokens,?device=device) dec_input_tokens[batch_ind,?shuffle_indices]?=?concat_tokens #?將全量?tokens?喂給?Decoder?解碼 decoded_tokens?=?self.decoder(dec_input_tokens)7.5 Loss Computation
取出解碼后的 mask tokens 送入頭部進(jìn)行像素值預(yù)測,然后將預(yù)測結(jié)果和 masked patches 比較,計(jì)算 MSE loss:
#?取出解碼后的?mask?tokens dec_mask_tokens?=?decoded_tokens[batch_ind,?mask_ind,?:] #?預(yù)測?masked?patches?的像素值 #?(b,?n_masked,?n_pixels_per_patch=patch_size**2?x?c) pred_mask_pixel_values?=?self.head(dec_mask_tokens) #?loss?計(jì)算 loss?=?F.mse_loss(pred_mask_pixel_values,?mask_patches)7.6 Reconstruction (Inference)
為了方便觀測重建效果,CW 將以上部分串起來在模型中集成了一個(gè)推理的方法:
@torch.no_grad def?predict(self,?x):self.eval()device?=?x.deviceb,?c,?h,?w?=?x.shape'''i.?Patch?partition'''num_patches?=?(h?//?self.patch_h)?*?(w?//?self.patch_w)#?(b,?c=3,?h,?w)->(b,?n_patches,?patch_size**2*c)patches?=?x.view(b,?c,h?//?self.patch_h,?self.patch_h,?w?//?self.patch_w,?self.patch_w).permute(0,?2,?4,?3,?5,?1).reshape(b,?num_patches,?-1)'''ii.?Divide?into?masked?&?un-masked?groups'''num_masked?=?int(self.mask_ratio?*?num_patches)#?Shuffle#?(b,?n_patches)shuffle_indices?=?torch.rand(b,?num_patches,?device=device).argsort()mask_ind,?unmask_ind?=?shuffle_indices[:,?:num_masked],?shuffle_indices[:,?num_masked:]#?(b,?1)batch_ind?=?torch.arange(b,?device=device).unsqueeze(-1)mask_patches,?unmask_patches?=?patches[batch_ind,?mask_ind],?patches[batch_ind,?unmask_ind]'''iii.?Encode'''unmask_tokens?=?self.encoder.patch_embed(unmask_patches)#?Add?position?embeddingsunmask_tokens?+=?self.encoder.pos_embed.repeat(b,?1,?1)[batch_ind,?unmask_ind?+?1]encoded_tokens?=?self.encoder.transformer(unmask_tokens)'''iv.?Decode'''enc_to_dec_tokens?=?self.enc_to_dec(encoded_tokens)#?(decoder_dim)->(b,?n_masked,?decoder_dim)mask_tokens?=?self.mask_embed[None,?None,?:].repeat(b,?num_masked,?1)#?Add?position?embeddingsmask_tokens?+=?self.decoder_pos_embed(mask_ind)#?(b,?n_patches,?decoder_dim)concat_tokens?=?torch.cat([mask_tokens,?enc_to_dec_tokens],?dim=1)#?dec_input_tokens?=?concat_tokensdec_input_tokens?=?torch.empty_like(concat_tokens,?device=device)#?Un-shuffledec_input_tokens[batch_ind,?shuffle_indices]?=?concat_tokensdecoded_tokens?=?self.decoder(dec_input_tokens)'''v.?Mask?pixel?Prediction'''dec_mask_tokens?=?decoded_tokens[batch_ind,?mask_ind,?:]#?(b,?n_masked,?n_pixels_per_patch=patch_size**2?x?c)pred_mask_pixel_values?=?self.head(dec_mask_tokens)#?比較下預(yù)測值和真實(shí)值mse_per_patch?=?(pred_mask_pixel_values?-?mask_patches).abs().mean(dim=-1)mse_all_patches?=?mse_per_patch.mean()print(f'mse?per?(masked)patch:?{mse_per_patch}?mse?all?(masked)patches:?{mse_all_patches}?total?{num_masked}?masked?patches')print(f'all?close:?{torch.allclose(pred_mask_pixel_values,?mask_patches,?rtol=1e-1,?atol=1e-1)}')'''vi.?Reconstruction'''recons_patches?=?patches.detach()#?Un-shuffle?(b,?n_patches,?patch_size**2?*?c)recons_patches[batch_ind,?mask_ind]?=?pred_mask_pixel_values#?模型重建的效果圖#?Reshape?back?to?image?#?(b,?n_patches,?patch_size**2?*?c)->(b,?c,?h,?w)recons_img?=?recons_patches.view(b,?h?//?self.patch_h,?w?//?self.patch_w,?self.patch_h,?self.patch_w,?c).permute(0,?5,?1,?3,?2,?4).reshape(b,?c,?h,?w)mask_patches?=?torch.randn_like(mask_patches,?device=mask_patches.device)#?mask?效果圖patches[batch_ind,?mask_ind]?=?mask_patchespatches_to_img?=?patches.view(b,?h?//?self.patch_h,?w?//?self.patch_w,?self.patch_h,?self.patch_w,?c).permute(0,?5,?1,?3,?2,?4).reshape(b,?c,?h,?w)return?recons_img,?patches_to_img出于娛樂的目的,CW 沒有考慮太多,快速寫下一個(gè)很 low 的推理 pipeline:
device?=?torch.device('cuda'?if?torch.cuda.is_available()?else?'cpu')#?讀入圖像并縮放到適合模型輸入的尺寸 from?PIL?import?Imageimg_raw?=?Image.open(os.path.join(BASE_DIR,?'mountain.jpg')) h,?w?=?img_raw.height,?img_raw.width ratio?=?h?/?w print(f"image?hxw:?{h}?x?{w}?mode:?{img_raw.mode}")img_size,?patch_size?=?(224,?224),?(16,?16) img?=?img_raw.resize(img_size) rh,?rw?=?img.height,?img.width print(f'resized?image?hxw:?{rh}?x?{rw}?mode:?{img.mode}') img.save(os.path.join(BASE_DIR,?'resized_mountain.jpg'))#?將圖像轉(zhuǎn)換成張量 from?torchvision.transforms?import?ToTensor,?ToPILImageimg_ts?=?ToTensor()(img).unsqueeze(0).to(device) print(f"input?tensor?shape:?{img_ts.shape}?dtype:?{img_ts.dtype}?device:?{img_ts.device}")#?實(shí)例化模型并加載訓(xùn)練好的權(quán)重 encoder?=?ViT(img_size,?patch_size,?dim=512,?mlp_dim=1024,?dim_per_head=64) decoder_dim?=?512 mae?=?MAE(encoder,?decoder_dim,?decoder_depth=6) weight?=?torch.load(os.path.join(BASE_DIR,?'mae.pth'),?map_location='cpu') mae.to(device)#?推理 #?模型重建的效果圖,mask?效果圖 recons_img_ts,?masked_img_ts?=?mae.predict(img_ts) recons_img_ts,?masked_img_ts?=?recons_img_ts.cpu().squeeze(0),?masked_img_ts.cpu().squeeze(0)#?將結(jié)果保存下來以便和原圖比較 recons_img?=?ToPILImage()(recons_img_ts) recons_img.save(os.path.join(BASE_DIR,?'recons_mountain.jpg'))masked_img?=?ToPILImage()(masked_img_ts) masked_img.save(os.path.join(BASE_DIR,?'masked_mountain.jpg'))人在 cafe 時(shí)間有限,CW 試著用 1 張圖片訓(xùn)練少輪迭代,該圖是我十月份到可可西里無人區(qū)拍攝的風(fēng)景,然后直接用訓(xùn)好的模型在這張?jiān)瓐D上進(jìn)行推理,以下是實(shí)驗(yàn)結(jié)果:
▲ 原圖
▲?mask(ratio=75%)圖
▲?模型重建效果
由于是在訓(xùn)練集上推理、肉眼也看不出來模型重建的效果圖與原圖的差別,因此并沒有太大的意義,但起碼保證了代碼可以跑通,模型可以成功擬合數(shù)據(jù),作為在 cafe 喝咖啡的附加娛樂項(xiàng)目還是能夠過把癮的。
附錄
Encoder 中 ViT 的實(shí)現(xiàn) & Decoder 中 Transformer 的實(shí)現(xiàn)如下:
End
圖像和語言是不同性質(zhì)的信號,圖像不像語言一樣天然是由一個(gè)個(gè)可分解的字詞組成,它是連續(xù)的信號。因此,為了更遵循圖像的“本性”,MAE 掩碼的時(shí)候是對整體圖像區(qū)域隨機(jī)掩碼,而非有意地對圖像做語義性的分割(比如有意地去 mask 掉一些物體或特定區(qū)域)。同樣地,MAE 重建的目標(biāo)是像素值,而非語義實(shí)體(什么圖像化 token 等)。
果然研究/解決一類問題還是要貼近其本質(zhì)才能更好地 work,并且能夠持久 work 的方法也通常是簡潔而非花里胡哨的,因?yàn)楸举|(zhì)就是最純真的東西,所以說,大道至簡。
特別鳴謝
感謝 TCCI 天橋腦科學(xué)研究院對于 PaperWeekly 的支持。TCCI 關(guān)注大腦探知、大腦功能和大腦健康。
更多閱讀
#投 稿?通 道#
?讓你的文字被更多人看到?
如何才能讓更多的優(yōu)質(zhì)內(nèi)容以更短路徑到達(dá)讀者群體,縮短讀者尋找優(yōu)質(zhì)內(nèi)容的成本呢?答案就是:你不認(rèn)識的人。
總有一些你不認(rèn)識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋梁,促使不同背景、不同方向的學(xué)者和學(xué)術(shù)靈感相互碰撞,迸發(fā)出更多的可能性。?
PaperWeekly 鼓勵高校實(shí)驗(yàn)室或個(gè)人,在我們的平臺上分享各類優(yōu)質(zhì)內(nèi)容,可以是最新論文解讀,也可以是學(xué)術(shù)熱點(diǎn)剖析、科研心得或競賽經(jīng)驗(yàn)講解等。我們的目的只有一個(gè),讓知識真正流動起來。
📝?稿件基本要求:
? 文章確系個(gè)人原創(chuàng)作品,未曾在公開渠道發(fā)表,如為其他平臺已發(fā)表或待發(fā)表的文章,請明確標(biāo)注?
? 稿件建議以?markdown?格式撰寫,文中配圖以附件形式發(fā)送,要求圖片清晰,無版權(quán)問題
? PaperWeekly 尊重原作者署名權(quán),并將為每篇被采納的原創(chuàng)首發(fā)稿件,提供業(yè)內(nèi)具有競爭力稿酬,具體依據(jù)文章閱讀量和文章質(zhì)量階梯制結(jié)算
📬?投稿通道:
? 投稿郵箱:hr@paperweekly.site?
? 來稿請備注即時(shí)聯(lián)系方式(微信),以便我們在稿件選用的第一時(shí)間聯(lián)系作者
? 您也可以直接添加小編微信(pwbot02)快速投稿,備注:姓名-投稿
△長按添加PaperWeekly小編
🔍
現(xiàn)在,在「知乎」也能找到我們了
進(jìn)入知乎首頁搜索「PaperWeekly」
點(diǎn)擊「關(guān)注」訂閱我們的專欄吧
·
總結(jié)
以上是生活随笔為你收集整理的别再无聊地吹捧了,一起来动手实现MAE玩玩吧!的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。