Pytorch采坑记录:DDP加载之前的checkpoint后loss上升(metric下降)
??最近在鼓搗使用pytorch的distributeddataparallel這個(gè)API搭一個(gè)數(shù)據(jù)并行的訓(xùn)練測(cè)試任務(wù),過程中遇到了一個(gè)問題,做一下記錄。
1、問題
??使用DDP打包了一個(gè)模型訓(xùn)練了一段時(shí)間,loss不斷下降metric不斷上升,一切都是很正常的現(xiàn)象。當(dāng)因?yàn)橐馔鈺和;蛘呤謩?dòng)暫停更改學(xué)習(xí)率而停止了程序,再開啟程序加載之前的checkpoint繼續(xù)訓(xùn)練,卻發(fā)現(xiàn)loss突然比之前上升或者metric比之前下降了很多。仔細(xì)看了一下loss的值,發(fā)現(xiàn)直接回到剛開始第一次訓(xùn)練模型時(shí)的水平,仿佛checkpoint根本沒加載進(jìn)去,是從初始化開始訓(xùn)練的一樣。
2、原因分析
根據(jù)我之前的框架使用經(jīng)驗(yàn),認(rèn)為可能的原因有以下兩點(diǎn):
2.1 模型的train和eval模式問題
??由于很多算子在訓(xùn)練模式和測(cè)試模式下的前向傳播原理不同,例如batchnorm和dropout等,導(dǎo)致幾乎所有的框架都會(huì)對(duì)模型設(shè)置一個(gè)train或eval的flag。Pytorch可以通過調(diào)用model.train()或model.eval()將模型的狀態(tài)進(jìn)行切換。在訓(xùn)練模式下如果模型是eval狀態(tài)或者在推理模式下模型是train狀態(tài)都會(huì)使得結(jié)果計(jì)算不正確,可能是導(dǎo)致上述問題的一個(gè)原因。
??但這個(gè)猜想很快就被我給否掉了。第一,按照我的經(jīng)驗(yàn),如果一個(gè)模型已經(jīng)訓(xùn)練到一個(gè)比較好的狀態(tài),即便是搞混了train和eval的狀態(tài)flag,結(jié)果雖然不對(duì)但是一般也不會(huì)差的特別多。我之前用0.1學(xué)習(xí)率訓(xùn)練了七八個(gè)epoch,損失已經(jīng)到了0.4~0.5左右。再次訓(xùn)練加載checkpoint損失直接飆到了差不多快到8了,這個(gè)跳躍太大了。第二,我去check了一下我的代碼,發(fā)現(xiàn)并沒有出現(xiàn)train和eval搞混的問題(手動(dòng)狗頭)。
2.2 模型沒有正確的加載進(jìn)去
??出現(xiàn)上述問題的另外一個(gè)可能的原因是:Pytorch沒有正確的將模型加載進(jìn)去。經(jīng)常使用pytorch的同學(xué)可能都遇到過這樣一種情況:自己設(shè)計(jì)了一個(gè)網(wǎng)絡(luò)用來做某項(xiàng)任務(wù),選擇了某個(gè)經(jīng)典分類模型(如resnet等)的特征提取部分作為backbone。訓(xùn)練時(shí)在github上下載了已經(jīng)在ImageNet數(shù)據(jù)集上pretrain的分類模型,并把這個(gè)模型的特征提取部分的權(quán)重直接加載到自己的模型中實(shí)現(xiàn)backbone預(yù)訓(xùn)練。但是效果卻并不好,可能的原因之一就是backbone并沒有成功的加載進(jìn)去。
??Pytorch中模型參數(shù)的保存底層使用的是字典的結(jié)構(gòu),因此參數(shù)加載需要保證參數(shù)名必須是一一對(duì)應(yīng)的。常用的一個(gè)加載模型參數(shù)的API是load_state_dict,其中有一個(gè)參數(shù)是strict=True,這個(gè)參數(shù)用來控制加載模型是否是“嚴(yán)格”的。嚴(yán)格指的是代碼模型定義里的所有parameter和buffer必須和要加載的checkpoint里的parameter和buffer的參數(shù)名、參數(shù)維度、參數(shù)類型等能夠一一對(duì)應(yīng)上,一個(gè)都不能多也不能少,否則就會(huì)報(bào)錯(cuò)。strict=False則可以允許代碼模型定義里的部分parameter或buffer和checkpoint中的對(duì)應(yīng)不上,如果有能對(duì)應(yīng)上的就加載,否則就忽略。比如下面的情況,當(dāng)strict=False時(shí),parameter2、3、4和5可以被正確加載,parameter1和6不會(huì)被加載而采用用戶定義的方式初始化;當(dāng)strict=True時(shí),加載會(huì)報(bào)錯(cuò)。
??在我遇到的問題中,經(jīng)過確認(rèn)我排除了這個(gè)可能性。checkpoint是用我自定義的模型訓(xùn)練得到的而不是從網(wǎng)上下載的,模型定義我沒有更改過因此和之前的是一樣的,而且我設(shè)置了strict=True,也沒有報(bào)錯(cuò)說明模型是被正確加載進(jìn)去的。
2.3 DistributedDataParallel問題
??以上兩種思考沒有解決我的問題,此時(shí)我痛定思痛仔細(xì)回想一下整個(gè)過程。同樣的代碼同樣的邏輯之前不做數(shù)據(jù)并行的時(shí)候是沒有問題的,但是一做DistributedDataParallel訓(xùn)練就出現(xiàn)了問題,說明bug出在DistributedDataParallel這里。看了一下這個(gè)API的源碼,找到了問題所在。在這個(gè)類的__init__函數(shù)里有這么一段:
class DistributedDataParallel(Module):def __init__(self, ...):...# Sync params and buffersself._sync_params_and_buffers(authoritative_rank=0)...也就是說在調(diào)用這個(gè)API把一個(gè)普通的model打包成一個(gè)ddp的model后,即實(shí)例化一個(gè)DistributedDataParallel對(duì)象的時(shí)候,就已經(jīng)完成了模型的parameter和buffer在主進(jìn)程模型和其他進(jìn)程上replica的同步。而我的代碼里,是先實(shí)例化了一個(gè)ddp對(duì)象,然后才去加載checkpoint 。
... model = MyModel() model.to(device=rank) model = nn.parallel.DistributedDataParallel(model, devices=[rank]) if rank == 0:ret = model.load_state_dict(torch.load(xxx), strict=True) ...此時(shí)代碼的執(zhí)行過程是:1、實(shí)例化一個(gè)MyModel對(duì)象并隨機(jī)初始化;2、實(shí)例化一個(gè)ddp對(duì)象并用之前隨機(jī)初始化的model去同步其他進(jìn)程上replica的parameter和buffer;3、將checkpoint的parameter和buffer加載到主進(jìn)程上的model中。此時(shí)其他幾個(gè)進(jìn)程上的model的parameter和buffer還都是隨機(jī)初始化的,在前向和反向傳播時(shí)雖然主進(jìn)程上的model給出了類似之前checkpoint比較準(zhǔn)確的結(jié)果。可是其他幾個(gè)子進(jìn)程上的模型由于參數(shù)是隨機(jī)初始化的所以結(jié)果差的很遠(yuǎn),各個(gè)進(jìn)程上的梯度經(jīng)過reduce_mean后就錯(cuò)的很離譜了。因此應(yīng)該調(diào)整一下代碼的順序?yàn)?#xff1a;
... model = MyModel() model.to(device=rank) if rank == 0:ret = model.load_state_dict(torch.load(xxx), strict=True) model = nn.parallel.DistributedDataParallel(model, devices=[rank]) ...??此時(shí)仍然有一個(gè)小小的bug,就是通過DistributedDataParallel這個(gè)API去打包模型后,模型的所有參數(shù)的名字都會(huì)多一個(gè)module的前綴,還是看一下API的源碼:
class DistributedDataParallel(Module):def __init__(self, module, ...):...self.module = module...熟悉Pytorch.nn.Module這個(gè)類的變量命名規(guī)則的同學(xué)應(yīng)該知道,加了這個(gè)成員變量賦值的語句后,所有模型變量的名字前綴都會(huì)多一個(gè)module。比如MyModel()實(shí)例化的對(duì)象中有一個(gè)名為conv1.weight的參數(shù),經(jīng)過DDP打包后得到的新模型中,對(duì)應(yīng)的參數(shù)變量名會(huì)變?yōu)閙odule.conv1.weight,一種解決辦法是可以通過保存模型時(shí)指定保存DDP對(duì)象的module模塊來消除這個(gè)前綴。
??水平有限,歡迎討論。
創(chuàng)作挑戰(zhàn)賽新人創(chuàng)作獎(jiǎng)勵(lì)來咯,堅(jiān)持創(chuàng)作打卡瓜分現(xiàn)金大獎(jiǎng)總結(jié)
以上是生活随笔為你收集整理的Pytorch采坑记录:DDP加载之前的checkpoint后loss上升(metric下降)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 北京大学药学院张亮仁教授/刘振明研究员课
- 下一篇: pytorch安装问题:路径不对导致no