关于炼丹,你是否知道这些细节?
作者丨Fatescript
來源丨h(huán)ttps://zhuanlan.zhihu.com/p/450779978
編輯丨GiantPandaCV
序
本文算是我工作一年多以來的一些想法和經(jīng)驗(yàn),最早發(fā)布在曠視研究院內(nèi)部的論壇中,本著開放和分享的精神發(fā)布在我的知乎專欄中,如果想看干貨的話可以直接跳過動(dòng)機(jī)部分。另外,后續(xù)在這個(gè)專欄中,我會(huì)做一些關(guān)于原理和設(shè)計(jì)方面的一些分享,希望能給領(lǐng)域從業(yè)人員提供一些看待問題的不一樣的視角。
動(dòng)機(jī)
前段時(shí)間走在路上,一直在思考一個(gè)問題:我的時(shí)間開銷很多都被拿去給別人解釋一些在我看起來顯而易見的問題了,比如( https://link.zhihu.com/?target=https%3A//github.com/Megvii- BaseDetection/cvpods )里面的一些code寫法問題(雖然這在某些方面說明了文檔建設(shè)的不完善),而這變相導(dǎo)致了我實(shí)際工作時(shí)間的減少,如何讓別人少問一些我覺得答案顯而易見的問題?如何讓別人提前規(guī)避一些不必要的坑?只有解決掉這樣的一些問題,我才能從一件件繁瑣的小事中解放出來,把精力放在我真正關(guān)心的事情上去。
其實(shí)之前同事有跟我說過類似的話,每次帶一個(gè)新人,都要告訴他:你的實(shí)現(xiàn)需要注意這里blabla,還要注意那里blabla。說實(shí)話,我很佩服那些帶intern時(shí)候非常細(xì)致和知無不言的人,但我本性上并不喜歡每次花費(fèi)時(shí)間去解釋一些我覺得顯而易見的問題,所以我寫下了這個(gè)帖子,把我踩過的坑和留下來的經(jīng)驗(yàn)分享出去。希望能夠方便別人,同時(shí)也節(jié)約我的時(shí)間。
加入曠視以來,個(gè)人一直在做一些關(guān)于框架相關(guān)的內(nèi)容,所以內(nèi)容主要偏向于模型訓(xùn)練之類的工作。因?yàn)?一個(gè)擁有知識(shí)的人是無法想象知識(shí)在別人腦海中的樣子的(the curse of knowledge),所以我只能選取被問的最多的,和我認(rèn)為最應(yīng)該知道的 。
準(zhǔn)備好了的話,我們就啟航出發(fā)(另,這篇專欄文章會(huì)長期進(jìn)行更新)。
坑/經(jīng)驗(yàn)
Data模塊
python圖像處理用的最多的兩個(gè)庫是opencv和Pillow(PIL),但是兩者讀取出來的圖像并不一樣, opencv讀取的圖像格式的三個(gè)通道是BGR形式的,但是PIL是RGB格式的 。這個(gè)問題看起來很小,但是衍生出來的坑可以有很多,最常見的場景就是數(shù)據(jù)增強(qiáng)和預(yù)訓(xùn)練模型中。比如有些數(shù)據(jù)增強(qiáng)的方法是基于channel維度的,比如megengine里面的HueTransform,這一行代碼 (https://github.com/MegEngine/MegEngine/blob/4d72e7071d6b8f8240edc56c6853384850b7407f/imperative/python/megengine/data/transform/vision/transform.py#L958 ) 顯然是需要確保圖像是BGR的,但是經(jīng)常會(huì)有人只看有Transform就無腦用了,從來沒有考慮過這些問題。
接上條,RGB和BGR的另一個(gè)問題就是導(dǎo)致預(yù)訓(xùn)練模型載入后訓(xùn)練的方式不對(duì),最常見的場景就是預(yù)訓(xùn)練模型的input channel是RGB的(例如torch官方來的預(yù)訓(xùn)練模型),然后你用cv2做數(shù)據(jù)處理,最后還忘了convert成RGB的格式,那么就是會(huì)有問題。這個(gè)問題應(yīng)該很多煉丹的同學(xué)沒有注意過,我之前寫CenterNet-better(https://github.com/FateScript/CenterNet-better)就發(fā)現(xiàn)CenterNet(https://github.com/xingyizhou/CenterNet)存在這么一個(gè)問題,要知道當(dāng)時(shí)這可是一個(gè)有著3k多star的倉庫,但是從來沒有人意識(shí)到有這個(gè)問題。當(dāng)然,依照我的經(jīng)驗(yàn),如果你訓(xùn)練的iter足夠多,即使你的channel有問題,對(duì)于結(jié)果的影響也會(huì)非常小。不過,既然能做對(duì),為啥不注意這些問題一次性做對(duì)呢?
torchvision中提供的模型,都是輸入圖像經(jīng)過了ToTensor操作train出來的。也就是說最后在進(jìn)入網(wǎng)絡(luò)之前會(huì)統(tǒng)一除以255從而將網(wǎng)絡(luò)的輸入變到0到1之間。torchvision的文檔(https://pytorch.org/vision/stable/models.html)給出了他們使用的mean和std,也是0-1的mean和std。如果你使用torch預(yù)訓(xùn)練的模型,但是輸入還是0-255的,那么恭喜你,在載入模型上你又會(huì)踩一個(gè)大坑(要么你的圖像先除以255,要么你的code中mean和std的數(shù)值都要乘以255)。
ToTensor之后接數(shù)據(jù)處理的坑。上一條說了ToTensor之后圖像變成了0到1的,但是一些數(shù)據(jù)增強(qiáng)對(duì)數(shù)值做處理的時(shí)候,是針對(duì)標(biāo)準(zhǔn)圖像,很多人ToTensor之后接了這樣一個(gè)數(shù)據(jù)增強(qiáng),最后就是練出來的丹是廢的(心疼電費(fèi)QaQ)。
數(shù)據(jù)集里面有一個(gè)圖特別詭異,只要train到那一張圖就會(huì)炸顯存(CUDA OOM),別的圖訓(xùn)練起來都沒有問題,應(yīng)該怎么處理?通常出現(xiàn)這個(gè)問題,首先判斷數(shù)據(jù)本身是不是有問題。如果數(shù)據(jù)本身有問題,在一開始生成Dataset對(duì)象的時(shí)候去掉就行了。如果數(shù)據(jù)本身沒有問題,只不過因?yàn)橐恍┨厥庠驅(qū)е嘛@存炸了(比如檢測中圖像的GT boxes過多的問題),可以catch一個(gè)CUDA OOM的error之后將一些邏輯放在CPU上,最后retry一下,這樣只是會(huì)慢一個(gè)iter,但是訓(xùn)練過程還是可以完整走完的,在我們開源的YOLOX里有類似的參考code(https://github.com/Megvii-BaseDetection/YOLOX/blob/0.1.0/yolox/models/yolo_head.py#L330-L334)。
pytorch中dataloader的坑。有時(shí)候會(huì)遇到pytorch num_workers=0(也就是單進(jìn)程)沒有問題,但是多進(jìn)程就會(huì)報(bào)一些看不懂的錯(cuò)的現(xiàn)象,這種情況通常是因?yàn)閠orch到了ulimit的上限,更核心的原因是 torch的dataloader不會(huì)釋放文件描述符 (參考issue: https://github.com/pytorch/pytorch/issues/973)。可以u(píng)limit -n 看一下機(jī)器的設(shè)置。跑程序之前修改一下對(duì)應(yīng)的數(shù)值。
opencv和dataloader的神奇聯(lián)動(dòng)。很多人經(jīng)常來問為啥要寫cv2.setNumThreads(0),其實(shí)是因?yàn)閏v2在做resize等op的時(shí)候會(huì)用多線程,當(dāng)torch的dataloader是多進(jìn)程的時(shí)候,多進(jìn)程套多線程,很容易就卡死了(具體哪里死鎖了我沒探究很深)。除了setNumThreads之外,通常還要加一句cv2.ocl.setUseOpenCL(False),原因是cv2使用opencl和cuda一起用的時(shí)候通常會(huì)拖慢速度,加了萬事大吉,說不定還能加速。感謝評(píng)論區(qū) @Yuxin Wu(https://www.zhihu.com/people/ppwwyyxx)?大大的指正
dataloader會(huì)在epoch結(jié)束之后進(jìn)行類似重新加載的操作,復(fù)現(xiàn)這個(gè)問題的code稍微有些長,放在后面了。這個(gè)問題算是可以說是一個(gè)高級(jí)bug/feature了,可能導(dǎo)致的問題之一就是煉丹師在本地的code上進(jìn)行了一些修改,然后訓(xùn)練過程直接加載進(jìn)去了。解決方法也很簡單,讓你的sampler源源不斷地產(chǎn)生數(shù)據(jù)就好,這樣即使本地code有修改也不會(huì)加載進(jìn)去。
Module模塊
BatchNorm在訓(xùn)練和推斷的時(shí)候的行為是不一致的。這也是新人最常見的錯(cuò)誤(類似的算子還有dropout,這里提一嘴, pytorch的dropout在eval的時(shí)候行為是Identity ,之前有遇到過實(shí)習(xí)生說dropout加了沒效果,直到我看了他的code:x = F.dropout(x, p=0.5)
BatchNorm疊加分布式訓(xùn)練的坑。在使用DDP(DistributedDataParallel)進(jìn)行訓(xùn)練的時(shí)候,每張卡上的BN統(tǒng)計(jì)量是可能不一樣的,仔細(xì)檢查broadcast_buffer這個(gè)參數(shù) 。DDP的默認(rèn)行為是在forward之前將rank0 的 buffer做一次broadcast(broadcast_buffer=True),但是一些常用的開源檢測倉庫是將broadcast_buffer設(shè)置成False的(參考:mmdet(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206)?和 detectron2(https://github.com/facebookresearch/detectron2/blob/f50ec07cf220982e2c4861c5a9a17c4864ab5bfd/tools/plain_train_net.py#L206),我猜是在檢測任務(wù)中因?yàn)閎atchsize過小,統(tǒng)一用卡0的統(tǒng)計(jì)量會(huì)掉點(diǎn)) 這個(gè)問題在一邊訓(xùn)練一邊測試的code中更常見 ,比如說你train了5個(gè)epoch,然后要分布式測試一下。一般的邏輯是將數(shù)據(jù)集分到每塊卡上,每塊卡進(jìn)行inference,最后gather到卡0上進(jìn)行測點(diǎn)。但是 因?yàn)槊繌埧ńy(tǒng)計(jì)量是不一樣的,所以和那種把卡0的模型broadcast到不同卡上測試出來的結(jié)果是不一樣的。這也是為啥通常訓(xùn)練完測的點(diǎn)和單獨(dú)起了一個(gè)測試腳本跑出來的點(diǎn)不一樣的原因 (當(dāng)然你用SyncBN就不會(huì)有這個(gè)問題)。
Pytorch的SyncBN在1.5之前一直實(shí)現(xiàn)的有bug,所以有一些老倉庫是存在使用SyncBN之后掉點(diǎn)的問題的。
用了多卡開多尺度訓(xùn)練,明明尺度更小了,但是速度好像不是很理想?這個(gè)問題涉及到多卡的原理,因?yàn)榉植际接?xùn)練的時(shí)候,在得到新的參數(shù)之后往往需要進(jìn)行一次同步。假設(shè)有兩張卡,卡0的尺度非常小,卡1的尺度非常大,那么就會(huì)出現(xiàn)卡0始終在等卡1,于是就出現(xiàn)了雖然有的尺度變小了,但是整體的訓(xùn)練速度并沒有變快的現(xiàn)象(木桶效應(yīng))。解決這個(gè)問題的思路就是 盡量把負(fù)載拉均衡一些 。
多卡的小batch模擬大batch(梯度累積)的坑。假設(shè)我們?cè)趩慰ㄏ轮荒苋耣atchsize = 2,那么為了模擬一個(gè)batchsize = 8的效果,通常的做法是forward / backward 4次,不清理梯度,step一次(當(dāng)然考慮BN的統(tǒng)計(jì)量問題這種做法和單純的batchsize=8肯定還是有一些差別的)。在多卡下,因?yàn)檎{(diào)用loss.backward的時(shí)候會(huì)做grad的同步,所以說前三次調(diào)用backward的時(shí)候需要加ddp.no_sync(https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=no_sync#torch.nn.parallel.DistributedDataParallel.no_sync)的context manager(不加的話,第一次bp之后,各個(gè)卡上的grad此時(shí)會(huì)進(jìn)行同步),最后一次則不需要加。當(dāng)然,我看很多倉庫并沒有這么做,我只能理解他們就是單純想做梯度累積(BTW,加了ddp.no_sync會(huì)使得程序快一些,畢竟加了之后bp過程是無通訊的)。
浮點(diǎn)數(shù)的加法其實(shí)不遵守交換律的 ,這個(gè)通常能衍生出來GPU上的運(yùn)算結(jié)果不能嚴(yán)格復(fù)現(xiàn)的現(xiàn)象。可能一些非計(jì)算機(jī)軟件專業(yè)的同學(xué)并不理解這一件事情,直接自己開一個(gè)python終端體驗(yàn)可能會(huì)更好:
訓(xùn)練模塊
FP16訓(xùn)練/混合精度訓(xùn)練。使用Apex訓(xùn)練混合精度模型,在保存checkpoint用于繼續(xù)訓(xùn)練的時(shí)候,除了model和optimizer本身的state_dict之外,還需要保存一下amp的state_dict,這個(gè)在amp的文檔(https://link.zhihu.com/?target=https%3A//nvidia.github.io/apex/amp.html%23checkpointing)中也有提過。(當(dāng)然,經(jīng)驗(yàn)上來說忘了保存影響不大,會(huì)多花幾個(gè)iter search一個(gè)loss scalar出來)
多機(jī)分布式訓(xùn)練卡死的問題。好友 @NoahSYZhang(https://www.zhihu.com/people/syzhangbuaa) 遇到的一個(gè)坑。場景是先申請(qǐng)了兩個(gè)8卡機(jī),然后機(jī)器1和機(jī)器2用前4塊卡做通訊(local rank最大都是4,總共是兩機(jī)8卡)。可以初始化process group,但是在使用DDP的時(shí)候會(huì)卡死。原因在于pytorch在做DDP的時(shí)候會(huì)猜測一個(gè)rank,參考code(https://github.com/pytorch/pytorch/blob/0d437fe6d0ef17648072eb586484a4a5a080b094/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1622-L1630)。對(duì)于上面的場景,第二個(gè)機(jī)器上因?yàn)榇嬖诳?到卡8,而對(duì)應(yīng)的rank也是5到8,所以DDP就會(huì)認(rèn)為自己需要同步的是卡5到卡8,于是就卡死了。
復(fù)現(xiàn)Code
Data部分
from?torch.utils.data?import?DataLoader from?torch.utils.data?import?Dataset import?tqdm import?timeclass?SimpleDataset(Dataset):def?__init__(self,?length=400):self.length?=?lengthself.data_list?=?list(range(length))def?__getitem__(self,?index):data?=?self.data_list[index]time.sleep(0.1)return?datadef?__len__(self):return?self.lengthdef?train(local_rank):dataset?=?SimpleDataset()dataloader?=?DataLoader(dataset,?batch_size=1,?num_workers=2)iter_loader?=?iter(dataloader)max_iter?=?100000for?_?in?tqdm.tqdm(range(max_iter)):try:_?=?next(iter_loader)except?StopIteration:print("Refresh?here?!!!!!!!!")iter_loader?=?iter(dataloader)_?=?next(iter_loader)if?__name__?==?"__main__":import?torch.multiprocessing?as?mpmp.spawn(train,?args=(),?nprocs=2,?daemon=False)當(dāng)程序運(yùn)行起來的時(shí)候,可以在Dataset里面的__getitem__方法里面加一個(gè)print輸出一些內(nèi)容,在refresh之后,就會(huì)print對(duì)應(yīng)的內(nèi)容哦(看到現(xiàn)象是不是覺得自己以前煉的丹可能有問題了呢hhh)
一些碎碎念
一口氣寫了這么多條也有點(diǎn)累了,后續(xù)有踩到新坑的話我也會(huì)繼續(xù)更新這篇文章的。畢竟寫這篇文章是希望工作中不再會(huì)有人踩類似的坑 & 煉丹的人能夠?qū)ι疃葘W(xué)習(xí)框架有意識(shí)(雖然某種程度上來講這算是個(gè)心智負(fù)擔(dān))。
如果說今年來什么事情是最大的收獲的話,那就是理解了一個(gè)開放的生態(tài)是可以迸發(fā)出極強(qiáng)的活力的,也希望能看到更多的人來分享自己遇到的問題和解決的思路。畢竟探索的答案只是一個(gè)副產(chǎn)品,過程本身才是最大的財(cái)寶。
本文僅做學(xué)術(shù)分享,如有侵權(quán),請(qǐng)聯(lián)系刪文。
重磅!計(jì)算機(jī)視覺工坊-學(xué)習(xí)交流群已成立
掃碼添加小助手微信,可申請(qǐng)加入3D視覺工坊-學(xué)術(shù)論文寫作與投稿?微信交流群,旨在交流頂會(huì)、頂刊、SCI、EI等寫作與投稿事宜。
同時(shí)也可申請(qǐng)加入我們的細(xì)分方向交流群,目前主要有ORB-SLAM系列源碼學(xué)習(xí)、3D視覺、CV&深度學(xué)習(xí)、SLAM、三維重建、點(diǎn)云后處理、自動(dòng)駕駛、CV入門、三維測量、VR/AR、3D人臉識(shí)別、醫(yī)療影像、缺陷檢測、行人重識(shí)別、目標(biāo)跟蹤、視覺產(chǎn)品落地、視覺競賽、車牌識(shí)別、硬件選型、深度估計(jì)、學(xué)術(shù)交流、求職交流等微信群,請(qǐng)掃描下面微信號(hào)加群,備注:”研究方向+學(xué)校/公司+昵稱“,例如:”3D視覺?+ 上海交大 + 靜靜“。請(qǐng)按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)去相關(guān)微信群。原創(chuàng)投稿也請(qǐng)聯(lián)系。
▲長按加微信群或投稿
▲長按關(guān)注公眾號(hào)
3D視覺從入門到精通知識(shí)星球:針對(duì)3D視覺領(lǐng)域的視頻課程(三維重建系列、三維點(diǎn)云系列、結(jié)構(gòu)光系列、手眼標(biāo)定、相機(jī)標(biāo)定、激光/視覺SLAM、自動(dòng)駕駛等)、知識(shí)點(diǎn)匯總、入門進(jìn)階學(xué)習(xí)路線、最新paper分享、疑問解答五個(gè)方面進(jìn)行深耕,更有各類大廠的算法工程人員進(jìn)行技術(shù)指導(dǎo)。與此同時(shí),星球?qū)⒙?lián)合知名企業(yè)發(fā)布3D視覺相關(guān)算法開發(fā)崗位以及項(xiàng)目對(duì)接信息,打造成集技術(shù)與就業(yè)為一體的鐵桿粉絲聚集區(qū),近4000星球成員為創(chuàng)造更好的AI世界共同進(jìn)步,知識(shí)星球入口:
學(xué)習(xí)3D視覺核心技術(shù),掃描查看介紹,3天內(nèi)無條件退款
?圈里有高質(zhì)量教程資料、可答疑解惑、助你高效解決問題
覺得有用,麻煩給個(gè)贊和在看
總結(jié)
以上是生活随笔為你收集整理的关于炼丹,你是否知道这些细节?的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 外贸收款(解析重点)——上海赢支付win
- 下一篇: 【Crypto】BUGKU-抄错的字符