[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析
[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 論文簡(jiǎn)析及關(guān)鍵代碼簡(jiǎn)析
論文:https://arxiv.org/abs/2104.00323
代碼:https://github.com/dvlab-research/JigsawClustering
總結(jié)
本文提出了一種單批次(single-batch)的自監(jiān)督任務(wù)pretext task Jigsaw Cluster,相比于雙批次(dual-batches)的方法降低了計(jì)算量,同時(shí)利用了圖像內(nèi)的信息和圖像間的信息。
本文提出的任務(wù)構(gòu)造的主要流程如下如圖1所示,首先在一整個(gè)batch內(nèi)將 nnn 張圖像每張分為 m×mm\times mm×m 份圖塊,則共有 n×m×mn\times m\times mn×m×m 個(gè)圖塊。再將這些圖塊打亂(注意是一個(gè)batch內(nèi)所有的圖塊進(jìn)行打亂,而非某單張圖像內(nèi)打亂)后,再拼接為圖像。
本文設(shè)計(jì)的網(wǎng)絡(luò)(如圖2所示)在backbone提取特征之后有兩個(gè)分支:聚類分支和定位分支。聚類分支會(huì)完成一個(gè)有監(jiān)督聚類的任務(wù),將來自同一張?jiān)瓐D的不同圖塊(已被打亂)聚集到一簇(cluster,類)。作者使用了最近比較火的對(duì)比學(xué)習(xí)來完成這個(gè)有監(jiān)督聚類任務(wù)。而對(duì)于定位分支,則是要預(yù)測(cè)出圖塊在原圖中的位置,具體是由一個(gè)分類任務(wù)來完成,損失函數(shù)直接選用交叉熵?fù)p失。
算法細(xì)節(jié)如有重疊分塊、插值加池化等可見下面的原文翻譯。
源碼簡(jiǎn)析
以下是源碼中JigClu模型的關(guān)鍵幾步操作,筆者在進(jìn)行實(shí)驗(yàn)后將其中信號(hào)流的形狀等信息注釋在代碼中,希望能夠幫助大家理解,或者能夠?yàn)橄胍獜?fù)現(xiàn)并改進(jìn)本文的讀者提供一些參考。
@torch.no_grad()def _batch_gather_ddp(self, images): # images是長(zhǎng)度為4的列表,其中每個(gè)元素是形狀為 (n, 3, 112, 112)的tensor"""gather images from different gpus and shuffle between them*** Only support DistributedDataParallel (DDP) model. ***"""images_gather = []for i in range(4):batch_size_this = images[i].shape[0]images_gather.append(concat_all_gather(images[i]))batch_size_all = images_gather[i].shape[0]num_gpus = batch_size_all // batch_size_thisn,c,h,w = images_gather[0].shapepermute = torch.randperm(n*4).cuda()torch.distributed.broadcast(permute, src=0)images_gather = torch.cat(images_gather, dim=0)images_gather = images_gather[permute,:,:,:]col1 = torch.cat([images_gather[0:n], images_gather[n:2*n]], dim=3)col2 = torch.cat([images_gather[2*n:3*n], images_gather[3*n:]], dim=3)images_gather = torch.cat([col1, col2], dim=2)bs = images_gather.shape[0] // num_gpusgpu_idx = torch.distributed.get_rank()return images_gather[bs*gpu_idx:bs*(gpu_idx+1)], permute, ndef forward(self, images, progress):images_gather, permute, bs_all = self._batch_gather_ddp(images) # bs=16雙卡, len(images) 4, images_gather.shape (8, 3, 224, 224), permute.shape 64(即16*4), bs_all = 16# compute featuresq = self.encoder(images_gather) # bs=16雙卡, q.shape (8, 2048, 2, 2) q_gather = concat_all_gather(q) # bs=16雙卡, q_gather.shape (16, 2048, 2, 2) # 插值后池化,得到這個(gè)形狀n,c,h,w = q_gather.shapec1,c2 = q_gather.split([1,1],dim=2) # bs=16雙卡, c.shape (16, 2048, 1, 2)f1,f2 = c1.split([1,1],dim=3) # bs=16雙卡, f.shape (16, 2048, 1, 1)f3,f4 = c2.split([1,1],dim=3)q_gather = torch.cat([f1,f2,f3,f4],dim=0) # bs=16雙卡, q_gather.shape (64, 2048, 1, 1)q_gather = q_gather.view(n*4,-1) # bs=16雙卡, q_gather.shape (64, 2048)# clustering branchlabel_clu = permute % bs_all # permute: 0-(4*bs) 之間的隨機(jī)值, 取余則label_clu: 4組 0-bs之間的隨機(jī)值,即同一個(gè)值label_clu值是來自同一圖片的q_clu = self.encoder.fc_clu(q_gather) # bs=16雙卡,q_clu.shape (64, 128) 即(4*bs, dim)q_clu = nn.functional.normalize(q_clu, dim=1)loss_clu = self.criterion_clu(q_clu, label_clu)# location branchlabel_loc = torch.LongTensor([0]*bs_all+[1]*bs_all+[2]*bs_all+[3]*bs_all).cuda()label_loc = label_loc[permute]q_loc = self.encoder.fc_loc(q_gather)loss_loc = self.criterion_loc(q_loc, label_loc)return loss_clu, loss_loc筆者使用雙卡進(jìn)行實(shí)驗(yàn),batchsize設(shè)為16。
源碼中一些gather操作是為了適應(yīng)dp或者ddp訓(xùn)練,對(duì)理解算法本身沒有影響。
以下是筆者對(duì)原文部分進(jìn)行的翻譯,一些算法細(xì)節(jié)和實(shí)現(xiàn)細(xì)節(jié)可以從中找到,配合源碼注釋基本可以理解全文的算法思想。有疑惑或者異議歡迎留言討論。
原文部分翻譯
abstract
使用對(duì)比學(xué)習(xí)的無監(jiān)督表示學(xué)習(xí)取得了巨大的成功,該方法將每一訓(xùn)練批次復(fù)制來構(gòu)建對(duì)比對(duì),使每一訓(xùn)練批及其擴(kuò)增版本同時(shí)進(jìn)行前向傳播,導(dǎo)致額外計(jì)算。本文提出了一種新的jigsaw聚類 pretext task,該任務(wù)只需要將每個(gè)訓(xùn)練批次本身進(jìn)行前向傳播,并降低訓(xùn)練損失。我們的方法同時(shí)利用了圖像內(nèi)的和圖像間的信息,極大地超越了之前的基于單訓(xùn)練批次(single batch based)的方法。甚至得到了與使用對(duì)比訓(xùn)練的方法接近的結(jié)果,而相比之下本文方法只用了一半的訓(xùn)練批次。
我們的方法表明多批次訓(xùn)練是不必要的,并為未來的單批次無監(jiān)督的研究打開了大門
introduction
無監(jiān)督的視覺表示學(xué)習(xí),或者說自監(jiān)督學(xué)習(xí),是一個(gè)存在已久的問題,試圖在沒有人類監(jiān)督信號(hào)的情況下,得到一個(gè)通用特征提取器。這個(gè)目標(biāo)可以通過精心設(shè)計(jì)不帶有標(biāo)注的pretext task來訓(xùn)練特征提取器來達(dá)成。
根據(jù)pretext task的定義,大多數(shù)主流的方法分兩類:圖像內(nèi)(intra-image)的任務(wù)和圖像間(inter-image)的任務(wù)。圖像內(nèi)的任務(wù),包括colorization和jigsaw puzzle,設(shè)計(jì)一種一張圖像的變換,并訓(xùn)練一個(gè)網(wǎng)絡(luò)學(xué)習(xí)這種變換。由于每次只有訓(xùn)練批次本身需要前向傳播計(jì)算,所以我們將這些方法稱作單批次方法(single-batch methods)。這類任務(wù)只使用了一張圖片的信息就可以完成,這限制了特征提取器的學(xué)習(xí)能力。
最近幾年圖像間任務(wù)迅猛發(fā)展,要求網(wǎng)絡(luò)能夠辨別不同的圖像。對(duì)比學(xué)習(xí)現(xiàn)在很流行,因?yàn)樗梢越档驼龑?duì)的特征表示之間的距離,并擴(kuò)大負(fù)對(duì)的特征表示之間的距離。為了建構(gòu)正對(duì),訓(xùn)練過程需要使用經(jīng)過不同的數(shù)據(jù)擴(kuò)增的另一批次的數(shù)據(jù)。由于每個(gè)訓(xùn)練批次和它的擴(kuò)增過的版本要同時(shí)進(jìn)行前向傳播,我們將這些方法稱作雙批次方法(dual-batches methods)。這種方法在訓(xùn)練過程中大大提升了對(duì)資源的需求,如何能夠設(shè)計(jì)一種有效的基于單批次的方法,達(dá)到與雙批次相仿的性能仍舊是個(gè)問題。
本文中,我們提出了一個(gè)使用Jigsaw聚類(Jig-Clu)來有效訓(xùn)練無監(jiān)督模型的框架。該方法結(jié)合了拼圖和對(duì)比學(xué)習(xí)的優(yōu)點(diǎn),利用圖像內(nèi)部和圖像間的信息指導(dǎo)特征提取。它學(xué)習(xí)更全面的表達(dá)。
該方法在訓(xùn)練過程中只需要一個(gè)單批,但與其他單批方法相比,結(jié)果有很大提高。它甚至可以達(dá)到類似的結(jié)果與雙批次方法,但相比只有一半的訓(xùn)練批次。
jigsaw clustring task
在本文提出的JigClu任務(wù)中,同一批次內(nèi)的每張圖片被分成不同的塊,它們被隨機(jī)打亂在被接在一起,來形成一個(gè)新的批次用作訓(xùn)練。目標(biāo)就是將這個(gè)被打亂的恢復(fù)為原圖,如圖一所示。不同于以往的Jigsaw Puzzle任務(wù),原圖分成的塊是在整個(gè)批次內(nèi)被打亂的,而非在單張圖像內(nèi)。我們需要去預(yù)測(cè)的事每個(gè)塊屬于哪張圖片和每個(gè)塊在原圖中的位置。
我們使用蒙太奇(montage)圖像而非單個(gè)塊作為網(wǎng)絡(luò)的輸入。這個(gè)改動(dòng)大幅提升了任務(wù)的難度,并為網(wǎng)絡(luò)提供了更多的有用的信息供學(xué)習(xí)。網(wǎng)絡(luò)需要辨識(shí)出一張圖像的不同部分,并識(shí)別出它們?cè)瓉淼奈恢脧亩鴱亩嗝商?#xff08;multiple montage)輸入圖像中恢復(fù)原圖。
這個(gè)任務(wù)使得網(wǎng)絡(luò)能夠圖像內(nèi)和圖像間的信息,只需要通過對(duì)拼接后的圖像進(jìn)行前向傳播,與其他對(duì)比學(xué)習(xí)的任務(wù)相比只使用了一半的訓(xùn)練批次。
為了恢復(fù)來自交叉圖像的圖塊,我們?cè)O(shè)計(jì)了一個(gè)聚類分支和一個(gè)定位分支。如圖二所示,具體來說,我們首先將來自拼接圖像的全局特征圖解耦為每個(gè)圖塊的表示。然后這兩個(gè)分支對(duì)每個(gè)圖塊的特征表示進(jìn)行操作。聚類分支是將這些圖塊分為幾簇,每個(gè)簇只包含來自同一張圖像的圖塊。另一方面,定位分支,以圖像不可知的方式(image agnostic manner)預(yù)測(cè)每個(gè)圖塊的位置。
有了這兩個(gè)分支的預(yù)測(cè)結(jié)果,JigClu問題就得以解決。聚類分支作為一個(gè)有監(jiān)督聚類任務(wù)進(jìn)行訓(xùn)練,因?yàn)槲覀冎缊D塊是否來自同一張圖像。定位分支可以看作是一個(gè)分類任務(wù),其中每個(gè)圖塊會(huì)被分配一個(gè)標(biāo)簽,以此來表示其在原圖中的位置。定位分支預(yù)測(cè)所有圖塊的這個(gè)標(biāo)簽。
我們的方法得到了不錯(cuò)的結(jié)果,是因?yàn)槲覀兲岢龅娜蝿?wù)會(huì)使模型學(xué)習(xí)到不同種類的信息。一開始,從一張拼接的圖像中辨識(shí)出不同的圖塊迫使模型去捕捉圖像內(nèi)不實(shí)例級(jí)別(instance-level)的信息。這一級(jí)別的特征在其他的對(duì)比學(xué)習(xí)方法中是丟失了的。
進(jìn)一步,從多個(gè)輸入圖像中聚類到不同的圖塊有助于模型在圖像中學(xué)習(xí)圖像級(jí)別(image-level)的特征。這時(shí)最近的一些方法得到高質(zhì)量結(jié)果的關(guān)鍵。我們的方法保持了這一重要屬性。最后,將每個(gè)圖塊擺放到正確的位置又要求細(xì)節(jié)的定位信息,這時(shí)之前的單批次方法考慮到的。但是在最近的一些方法中被忽略了。我們認(rèn)為這種信息對(duì)于進(jìn)一步提升結(jié)果來說仍舊是重要的。
performance of our method
通過我們的方法進(jìn)行學(xué)習(xí),可以產(chǎn)生圖像內(nèi)的和圖像間的信息。這樣綜合的學(xué)習(xí)可以帶來一些優(yōu)勢(shì)(spectrum of superiority)。首先,我們的方法在訓(xùn)練階段只有一個(gè)批次,在Imagenet-1k的線性評(píng)估階段比其他單批次方法高了2.6%。 。。。
related work
handcrafted pretext tasks
訓(xùn)練無監(jiān)督模型的pretext task的方法有很多種。 將破壞過的圖像進(jìn)行恢復(fù)是一個(gè)重要主題,有with tasks of descriminating synthetic artifacts [18], colorization [20, 43], image inpainting [31], and denoising auto-encoders [37], 等。另外,許多方法通過一些變換生成persuade labels(?)來訓(xùn)練網(wǎng)絡(luò)。應(yīng)用包括預(yù)測(cè)兩個(gè)塊的關(guān)系,解決jigsaw puzzle,還有識(shí)別被替代的類。[]是一個(gè)進(jìn)階版的jigsaw puzzle,利用更復(fù)雜的方法選擇圖塊。視頻信息在訓(xùn)練無監(jiān)督模型時(shí)也很常用。
contrastive learning
我們的方法和對(duì)比學(xué)習(xí)也高度相關(guān),首先由[]提出,根據(jù)[]可以得到更好的性能。最近[],使用不同的擴(kuò)增方法構(gòu)建對(duì)比對(duì)取得了巨大的成功。尤其是,[]在pixel水平上利用圖像間和圖像內(nèi)的信息。我們注意到訓(xùn)練多批次圖像的對(duì)比學(xué)習(xí)方法需要大量的訓(xùn)練資源。通過新穎的在單批次內(nèi)設(shè)計(jì)對(duì)比對(duì),我們的工作解決了這個(gè)問題。
jigsaw clustering
本章,我們會(huì)給出本文所提出的任務(wù)的定義。我們使用一個(gè)很簡(jiǎn)單的網(wǎng)絡(luò),只需要對(duì)原始的骨干網(wǎng)絡(luò)進(jìn)行一點(diǎn)點(diǎn)調(diào)整。最后,我們?cè)O(shè)計(jì)了一個(gè)新穎的損失函數(shù)來更好地適應(yīng)我們的聚類任務(wù)。
the jigsaw clustering task
在一個(gè)批次 X\bf{X}X=x1,x2,…,xn=x_1,x_2,\dots,x_n=x1?,x2?,…,xn? 內(nèi),有 nnn 個(gè)隨機(jī)選擇的圖像。每張圖像 xix_ixi? 被分為 m×mm\times mm×m 個(gè)圖塊。共有 n×m×mn\times m\times mn×m×m 個(gè)圖塊。所有這些圖塊會(huì)被隨機(jī)重新排列來形成一組有蒙太奇圖像X′\bf{X'}X′=x1′,x2′,…,xn′=x'_1,x'_2,\dots,x'_n=x1′?,x2′?,…,xn′? 形成的新的批次。每張新圖同樣包含 m×mm\times mm×m 個(gè)圖塊,這些圖塊來自不同的原批次 X\bf{X}X 中的圖像。
任務(wù)就是對(duì)新批次 X\bf{X}X 中的這 n×m×mn\times m\times mn×m×m 個(gè)圖塊進(jìn)行聚類為 nnn 個(gè)簇,并且對(duì)同一簇的 $ m\times m$ 個(gè)圖塊預(yù)測(cè)位置來恢復(fù)出 nnn 張?jiān)瓐D,整個(gè)過程見圖1。
本文提出的任務(wù)的關(guān)鍵是使用蒙太奇圖像作為輸入而不是每單獨(dú)一個(gè)圖塊。值得注意的是,直接使用小圖塊作為輸入會(huì)導(dǎo)致solution只有全局信息。此外,小尺寸的輸入圖像在許多應(yīng)用中并不常見。僅在此處使用它們會(huì)引發(fā)pretext task和其他下游任務(wù)之間的圖像分辨率差異問題。這也可能導(dǎo)致性能下降。而簡(jiǎn)單地直接擴(kuò)展小圖塊將極大地提升訓(xùn)練資源。
我們將蒙太奇圖像作為輸入完美地避免了這些問題。首先,來自一個(gè)批次的輸入圖像與原批次有著相同的尺寸,這和最近的方法相比只消耗了一半的資源。更重要的是,為了更好地完成本任務(wù),網(wǎng)絡(luò)需要學(xué)習(xí)細(xì)節(jié)的圖像內(nèi)的特征,來辨別一張圖像中的不同圖塊,和全局的圖像間的特征來將來自同一張?jiān)瓐D的不同圖塊聚集在一起。我們觀察到全面特征的學(xué)習(xí)大幅加速了特征提取了的訓(xùn)練。更多實(shí)驗(yàn)結(jié)果見下一節(jié)。
在本方法中,分圖像的方法是很關(guān)鍵的。mmm 的選擇影響到任務(wù)的難度。我們的在ImageNet子集上的消融實(shí)驗(yàn)顯示 m=2m=2m=2 時(shí)得到最好的結(jié)果。我們推測(cè) mmm 過大會(huì)呈指數(shù)級(jí)地增加復(fù)雜度,使得網(wǎng)絡(luò)不能高效地學(xué)習(xí)。另外,我們觀察到將圖像切割為不連接的圖塊(disjoint pathches)并不是最優(yōu)的。如圖3所示,隨著交叉點(diǎn)的延伸,網(wǎng)絡(luò)學(xué)習(xí)到更好的特征。這時(shí)可以解釋的,因?yàn)槟承﹫D像的不同區(qū)域過于多樣化。如果沒有任何重疊的跡象,它們會(huì)給學(xué)習(xí)帶來困難。第5節(jié)會(huì)有更多解釋。
network design
我們?yōu)楸救蝿?wù)設(shè)計(jì)了一個(gè)新的解耦網(wǎng)絡(luò)。首先是特征提取器,可以是任何網(wǎng)絡(luò)[]。然后有一個(gè)無參數(shù)的解耦網(wǎng)絡(luò)來將特征分為 m×mm\times mm×m 個(gè)部分,對(duì)應(yīng)同一個(gè)輸入圖像的不同的塊。然后用一個(gè)MLP來嵌入每個(gè)塊的特征,用作聚類任務(wù);一個(gè)全連接層用來做定位任務(wù)。
解耦模塊首先將主干的特征映射插值為邊長(zhǎng)為 mmm 的倍數(shù)的新特征映射。我們是擴(kuò)大特征圖而非縮小從而避免信息丟失。舉個(gè)例子,比如ImageNet,輸入尺寸都是224x224.如果用ResNet-50作骨干網(wǎng)絡(luò),則提取到的特征是空間尺寸是 7x7的。如果 m=2m=2m=2 ,我們就將特征圖用雙線性插值搭配8x8。這樣特征圖的長(zhǎng)度就是 mmm 的倍數(shù),我們可以使用平均池化,來對(duì)特征圖進(jìn)行降采樣到 n×m×m×c^n\times m\times m\times \hat{c}n×m×m×c^ 。這樣,一個(gè)batch的就被分解為 (n×m×m)×c^(n\times m\times m)\times \hat{c}(n×m×m)×c^ ,即有 (n×m×m)(n\times m\times m)(n×m×m) 個(gè)維度為 c^\hat{c}c^ 的向量。
然后每個(gè)向量都經(jīng)過兩層MLP嵌入到長(zhǎng)度為 ccc ,來形成一組向量 Z=z1,z2,…,znmm\mathbf{Z}=z_1,z_2,\dots,z_{nmm}Z=z1?,z2?,…,znmm? 用作聚類任務(wù)。同時(shí), (n×m×m)×c^(n\times m\times m)\times \hat{c}(n×m×m)×c^ 的向量還會(huì)被送到一個(gè)作為分類器的全連接層,產(chǎn)生logits L=l1,l2,…,lnmm\mathbf{L}=l_1,l_2,\dots,l_{nmm}L=l1?,l2?,…,lnmm?,來完成定位任務(wù)。
我們的網(wǎng)絡(luò)是相當(dāng)高效的,這個(gè)額外的解耦模塊是不需要參數(shù)的。與近期的工作相比,取一批的計(jì)算方法基本相同,訓(xùn)練時(shí)只需取一批。這大大降低了訓(xùn)練成本。
loss functions
聚類分支是一個(gè)有監(jiān)督聚類任務(wù),因?yàn)?m×mm\times mm×m 個(gè)塊來自同一類。有監(jiān)督聚類任務(wù)很方便,我們使用對(duì)比學(xué)習(xí)來實(shí)現(xiàn)。我們將聚類的目標(biāo)是將來自同一類的物體(塊)拉到一起,將來自不同類的圖塊推開。我們使用余弦相似度來測(cè)量塊之間的距離。這樣來自同一簇的每一對(duì)塊,損失函數(shù)如下:
?i,j=?logexp(cos(zi,zj)/τ)∑k=1nmm1k≠iexp(cos(zi,zj)/τ)\ell_{i,j}=-log\frac{exp(cos(z_i,z_j)/\tau)}{\sum_{k=1}^{nmm}\mathbb{1}_{k\neq i}exp(cos(z_i,z_j)/\tau)} ?i,j?=?log∑k=1nmm?1k?=i?exp(cos(zi?,zj?)/τ)exp(cos(zi?,zj?)/τ)?
其中 1\mathbb{1}1 表示指示函數(shù)(indicator function),τ\tauτ 是溫度系數(shù),用來平滑或者加劇距離。最終的所有來自同一簇的圖塊對(duì)的損失函數(shù)可寫作:
Lclu=1nmm∑i(1mm?1∑j∈Ci?i,j)\mathcal{L}_{clu}=\frac{1}{nmm}\sum_i(\frac{1}{mm-1}\sum_{j\in C_i\ell_{i,j}}) Lclu?=nmm1?i∑?(mm?11?j∈Ci??i,j?∑?)
其中 CiC_iCi? 表示同一簇 iii 內(nèi)的圖塊的索引 。
定位分支被視作是一個(gè)分類任務(wù),損失函數(shù)是簡(jiǎn)單的交叉熵?fù)p失,寫作:
Lloc=CrossEntropy(L,Lgt)\mathcal{L}_{loc}=CrossEntropy(\mathbf{L,L_{gt}}) Lloc?=CrossEntropy(L,Lgt?)
我們提出的Ji個(gè)C路的總體損失則為:
L=αLclu+βLloc\mathcal{L}=\alpha\mathcal{L}_{clu}+\beta\mathcal{L}_{loc} L=αLclu?+βLloc?
在我們的實(shí)驗(yàn)中,α=β=1\alpha=\beta=1α=β=1 即可得到好的結(jié)果。
總結(jié)
以上是生活随笔為你收集整理的[2021-CVPR] Jigsaw Clustering for Unsupervised Visual Representation Learning 论文简析及关键代码简析的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 反思 大班 快乐的机器人_幼儿园大班教案
- 下一篇: 二战时德国陆军最有名的轻重武器