更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...
?在pytorch中停止梯度流的若干辦法,避免不必要模塊的參數(shù)更新
2020/4/11 FesianXu
前言
在現(xiàn)在的深度模型軟件框架中,如TensorFlow和PyTorch等等,都是實(shí)現(xiàn)了自動(dòng)求導(dǎo)機(jī)制的。在深度學(xué)習(xí)中,有時(shí)候我們需要對(duì)某些模塊的梯度流進(jìn)行精確地控制,包括是否允許某個(gè)模塊的參數(shù)更新,更新地幅度多少,是否每個(gè)模塊更新地幅度都是一樣的。這些問(wèn)題非常常見(jiàn),但是在實(shí)踐中卻很容易出錯(cuò),我們?cè)谶@篇文章中嘗試對(duì)第一個(gè)子問(wèn)題,也就是如果精確控制某些模型是否允許其參數(shù)更新,進(jìn)行總結(jié)。如有謬誤,請(qǐng)聯(lián)系指出,轉(zhuǎn)載請(qǐng)注明出處。
本文實(shí)驗(yàn)平臺(tái):pytorch 1.4.0, ubuntu 18.04, python 3.6
聯(lián)系方式:
e-mail: FesianXu@gmail.com
QQ: 973926198
github: https://github.com/FesianXu
知乎專(zhuān)欄: 計(jì)算機(jī)視覺(jué)/計(jì)算機(jī)圖形理論與應(yīng)用
微信公眾號(hào):
為什么我們要控制梯度流
為什么我們要控制梯度流?這個(gè)答案有很多個(gè),但是都可以歸結(jié)為避免不需要更新的模型模塊被參數(shù)更新。我們?cè)谏疃饶P陀?xùn)練過(guò)程中,很可能存在多個(gè)loss,比如GAN對(duì)抗生成網(wǎng)絡(luò),存在G_loss和D_loss,通常來(lái)說(shuō),我們通過(guò)D_loss只希望更新判別器(Discriminator),而生成網(wǎng)絡(luò)(Generator)并不需要,也不能被更新;生成網(wǎng)絡(luò)只在通過(guò)G_loss學(xué)習(xí)的情況下,才能被更新。這個(gè)時(shí)候,如果我們不控制梯度流,那么我們?cè)谟?xùn)練D_loss的時(shí)候,我們的前端網(wǎng)絡(luò)Generator和CNN難免也會(huì)被一起訓(xùn)練,這個(gè)是我們不期望發(fā)生的。
Fig 1.1 典型的GAN結(jié)構(gòu),由生成器和判別器組成。
多個(gè)loss的協(xié)調(diào)只是其中一種情況,還有一種情況是:我們?cè)谶M(jìn)行模型遷移的過(guò)程中,經(jīng)常采用某些已經(jīng)預(yù)訓(xùn)練好了的特征提取網(wǎng)絡(luò),比如VGG, ResNet之類(lèi)的,在適用到具體的業(yè)務(wù)數(shù)據(jù)集時(shí)候,特別是小數(shù)據(jù)集的時(shí)候,我們可能會(huì)希望這些前端的特征提取器不要更新,而只是更新末端的分類(lèi)器(因?yàn)閿?shù)據(jù)集很小的情況下,如果貿(mào)然更新特征提取器,很可能出現(xiàn)不期望的嚴(yán)重過(guò)擬合,這個(gè)時(shí)候的合適做法應(yīng)該是更新分類(lèi)器優(yōu)先),這個(gè)時(shí)候我們也可以考慮停止特征提取器的梯度流。
這些情況還有很多,我們?cè)趯?shí)踐中發(fā)現(xiàn),精確控制某些模塊的梯度流是非常重要的。筆者在本文中打算討論的是對(duì)某些模塊的梯度流的截?cái)?/strong>,而并沒(méi)有討論對(duì)某些模塊梯度流的比例縮放,或者說(shuō)最細(xì)粒度的梯度流控制,后者我們將會(huì)在后文中討論。
一般來(lái)說(shuō),截?cái)嗵荻攘骺梢杂袔追N思路:
停止計(jì)算某個(gè)模塊的梯度,在優(yōu)化過(guò)程中這個(gè)模塊還是會(huì)被考慮更新,然而因?yàn)樘荻纫呀?jīng)被截?cái)嗔?#xff0c;因此不能被更新。
- 設(shè)置tensor.detach():完全截?cái)嘀暗奶荻攘?/li>
- 設(shè)置參數(shù)的requires_grad屬性:單純不計(jì)算當(dāng)前設(shè)置參數(shù)的梯度,不影響梯度流
- torch.no_grad():效果類(lèi)似于設(shè)置參數(shù)的requires_grad屬性
在優(yōu)化器中設(shè)置不更新某個(gè)模塊的參數(shù),這個(gè)模塊的參數(shù)在優(yōu)化過(guò)程中就不會(huì)得到更新,然而這個(gè)模塊的梯度在反向傳播時(shí)仍然可能被計(jì)算。
我們后面分別按照這兩大類(lèi)思路進(jìn)行討論。
停止計(jì)算某個(gè)模塊的梯度
在本大類(lèi)方法中,主要涉及到了tensor.detach()和requires_grad的設(shè)置,這兩種都無(wú)非是對(duì)某些模塊,某些節(jié)點(diǎn)變量設(shè)置了是否需要梯度的選項(xiàng)。
tensor.detach()
tensor.detach()的作用是:
tensor.detach()會(huì)創(chuàng)建一個(gè)與原來(lái)張量共享內(nèi)存空間的一個(gè)新的張量,不同的是,這個(gè)新的張量將不會(huì)有梯度流流過(guò),這個(gè)新的張量就像是從原先的計(jì)算圖中脫離(detach)出來(lái)一樣,對(duì)這個(gè)新的張量進(jìn)行的任何操作都不會(huì)影響到原先的計(jì)算圖了。因此對(duì)此新的張量進(jìn)行的梯度流也不會(huì)流過(guò)原先的計(jì)算圖,從而起到了截?cái)嗟哪康摹?/p>
這樣說(shuō)可能不夠清楚,我們舉個(gè)例子。眾所周知,我們的pytorch是動(dòng)態(tài)計(jì)算圖網(wǎng)絡(luò),正是因?yàn)橛?jì)算圖的存在,才能實(shí)現(xiàn)自動(dòng)求導(dǎo)機(jī)制。考慮一個(gè)表達(dá)式:
如果用計(jì)算圖表示則如Fig 2.1所示。
Fig 2.1 計(jì)算圖示例
考慮在這個(gè)式子的基礎(chǔ)上,加上一個(gè)分支:
那么計(jì)算圖就變成了:
Fig 2.2 添加了新的分支后的計(jì)算圖
如果我們不detach() 中間的變量z,分別對(duì)pq和w進(jìn)行反向傳播梯度,我們會(huì)有:
x?=?torch.tensor(([1.0]),requires_grad=True)y?=?x**2
z?=?2*y
w=?z**3
#?This?is?the?subpath
#?Do?not?use?detach()
p?=?z
q?=?torch.tensor(([2.0]),?requires_grad=True)
pq?=?p*q
pq.backward(retain_graph=True)
w.backward()
print(x.grad)
輸出結(jié)果為 tensor([56.])。我們發(fā)現(xiàn),這個(gè)結(jié)果是吧pq和w的反向傳播結(jié)果都進(jìn)行了考慮的,也就是新增加的分支的反向傳播影響了原先主要枝干的梯度流。這個(gè)時(shí)候我們用detach()可以把p給從原先計(jì)算圖中脫離出來(lái),使得其不會(huì)干擾原先的計(jì)算圖的梯度流,如:
Fig 2.3 用了detach之后的計(jì)算圖
那么,代碼就對(duì)應(yīng)地修改為:
x?=?torch.tensor(([1.0]),requires_grad=True)y?=?x**2
z?=?2*y
w=?z**3
#?detach?it,?so?the?gradient?w.r.t?`p`?does?not?effect?`z`!
p?=?z.detach()
q?=?torch.tensor(([2.0]),?requires_grad=True)
pq?=?p*q
pq.backward(retain_graph=True)
w.backward()
print(x.grad)
這個(gè)時(shí)候,因?yàn)榉种У奶荻攘饕呀?jīng)影響不到原先的計(jì)算圖梯度流了,因此輸出為tensor([48.])。
這只是個(gè)計(jì)算圖的簡(jiǎn)單例子,在實(shí)際模塊中,我們同樣可以這樣用,舉個(gè)GAN的例子,代碼如:
????def?backward_D(self):????????#?Fake
????????#?stop?backprop?to?the?generator?by?detaching?fake_B
????????fake_AB?=?self.fake_B
????????#?fake_AB?=?self.fake_AB_pool.query(torch.cat((self.real_A,?self.fake_B),?1))
????????self.pred_fake?=?self.netD.forward(fake_AB.detach())
????????self.loss_D_fake?=?self.criterionGAN(self.pred_fake,?False)
????????#?Real
????????real_AB?=?self.real_B?#?GroundTruth
????????#?real_AB?=?torch.cat((self.real_A,?self.real_B),?1)
????????self.pred_real?=?self.netD.forward(real_AB)
????????self.loss_D_real?=?self.criterionGAN(self.pred_real,?True)
????????#?Combined?loss
????????self.loss_D?=?(self.loss_D_fake?+?self.loss_D_real)?*?0.5
????????self.loss_D.backward()
????def?backward_G(self):
????????#?First,?G(A)?should?fake?the?discriminator
????????fake_AB?=?self.fake_B
????????pred_fake?=?self.netD.forward(fake_AB)
????????self.loss_G_GAN?=?self.criterionGAN(pred_fake,?True)
????????#?Second,?G(A)?=?B
????????self.loss_G_L1?=?self.criterionL1(self.fake_B,?self.real_B)?*?self.opt.lambda_A
????????self.loss_G?=?self.loss_G_GAN?+?self.loss_G_L1
????????self.loss_G.backward()
????def?forward(self):
????????self.real_A?=?Variable(self.input_A)
????????self.fake_B?=?self.netG.forward(self.real_A)
????????self.real_B?=?Variable(self.input_B)
????#?先調(diào)用 forward, 再 D backward,?更新D之后;?再G backward,?再更新G
????def?optimize_parameters(self):
????????self.forward()
????????self.optimizer_D.zero_grad()
????????self.backward_D()
????????self.optimizer_D.step()
????????self.optimizer_G.zero_grad()
????????self.backward_G()
????????self.optimizer_G.step()
我們注意看第六行,self.pred_fake = self.netD.forward(fake_AB.detach())使得在反向傳播D_loss的時(shí)候不會(huì)更新到self.netG,因?yàn)閒ake_AB是由self.netG生成的,代碼如self.fake_B = self.netG.forward(self.real_A)。
設(shè)置requires_grad
tensor.detach()是截?cái)嗵荻攘鞯囊粋€(gè)好辦法,但是在設(shè)置了detach()的張量之前的所有模塊,梯度流都不能回流了(不包括這個(gè)張量本身,這個(gè)張量已經(jīng)脫離原先的計(jì)算圖了),如以下代碼所示:
x?=?torch.randn(2,?2)x.requires_grad?=?True
lin0?=?nn.Linear(2,?2)
lin1?=?nn.Linear(2,?2)
lin2?=?nn.Linear(2,?2)
lin3?=?nn.Linear(2,?2)
x1?=?lin0(x)
x2?=?lin1(x1)
x2?=?x2.detach()?#?此處設(shè)置了detach,之前的所有梯度流都不會(huì)回傳了
x3?=?lin2(x2)
x4?=?lin3(x3)
x4.sum().backward()
print(lin0.weight.grad)
print(lin1.weight.grad)
print(lin2.weight.grad)
print(lin3.weight.grad)
輸出為:
NoneNone
tensor([[-0.7784,?-0.7018],
????????[-0.4261,?-0.3842]])
tensor([[?0.5509,?-0.0386],
????????[?0.5509,?-0.0386]])
我們發(fā)現(xiàn)lin0.weight.grad和lin0.weight.grad都為None了,因?yàn)橥ㄟ^(guò)脫離中間張量,原先計(jì)算圖已經(jīng)和當(dāng)前回傳的梯度流脫離關(guān)系了。
這樣有時(shí)候不夠理想,因?yàn)槲覀兛赡艽嬖谥恍枰承┲虚g模塊不計(jì)算梯度,但是梯度仍然需要回傳的情況,在這種情況下,如下圖所示,我們可能只需要不計(jì)算B_net的梯度,但是我們又希望計(jì)算A_net和C_net的梯度,這個(gè)時(shí)候怎么辦呢?當(dāng)然,通過(guò)detach()這個(gè)方法是不能用了。
事實(shí)上,我們可以通過(guò)設(shè)置張量的requires_grad屬性來(lái)設(shè)置某個(gè)張量是否計(jì)算梯度,而這個(gè)不會(huì)影響梯度回傳,只會(huì)影響當(dāng)前的張量。修改上面的代碼,我們有:
x?=?torch.randn(2,?2)x.requires_grad?=?True
lin0?=?nn.Linear(2,?2)
lin1?=?nn.Linear(2,?2)
lin2?=?nn.Linear(2,?2)
lin3?=?nn.Linear(2,?2)
x1?=?lin0(x)
x2?=?lin1(x1)
for?p?in?lin2.parameters():
????p.requires_grad?=?False
x3?=?lin2(x2)
x4?=?lin3(x3)
x4.sum().backward()
print(lin0.weight.grad)
print(lin1.weight.grad)
print(lin2.weight.grad)
print(lin3.weight.grad)
輸出為:
tensor([[-0.0117,??0.9976],????????[-0.0080,??0.6855]])
tensor([[-0.0075,?-0.0521],
????????[-0.0391,?-0.2708]])
None
tensor([[0.0523,?0.5429],
????????[0.0523,?0.5429]])
啊哈,正是我們想要的結(jié)果,只有設(shè)置了requires_grad=False的模塊沒(méi)有計(jì)算梯度,但是梯度流又能夠回傳。
另外,設(shè)置requires_grad經(jīng)常用在對(duì)輸入變量和輸入的標(biāo)簽進(jìn)行新建的時(shí)候使用,如:
for?mat,label?in?dataloader:????mat?=?Variable(mat,?requires_grad=False)
????label?=?Variable(mat,requires_grad=False)
????...
當(dāng)然,通過(guò)把所有前端網(wǎng)絡(luò)都設(shè)置requires_grad=False,我們可以實(shí)現(xiàn)類(lèi)似于detach()的效果,也就是把該節(jié)點(diǎn)之前的所有梯度流回傳截?cái)唷R訴GG16為例子,如果我們只需要訓(xùn)練其分類(lèi)器,而固定住其特征提取器網(wǎng)絡(luò)的參數(shù),我們可以采用將前端網(wǎng)絡(luò)的所有參數(shù)的requires_grad設(shè)置為False,因?yàn)檫@個(gè)時(shí)候完全不需要梯度流的回傳,只需要前向計(jì)算即可。代碼如:
model?=?torchvision.models.vgg16(pretrained=True)for?param?in?model.features.parameters():
????param.requires_grad?=?False
torch.no_grad()
在對(duì)訓(xùn)練好的模型進(jìn)行評(píng)估測(cè)試時(shí),我們同樣不需要訓(xùn)練,自然也不需要梯度流信息了。我們可以把所有參數(shù)的requires_grad屬性設(shè)置為False,事實(shí)上,我們常用torch.no_grad()上下文管理器達(dá)到這個(gè)目的。即便輸入的張量屬性是requires_grad=True, ? torch.no_grad()可以將所有的中間計(jì)算結(jié)果的該屬性臨時(shí)轉(zhuǎn)變?yōu)镕alse。
如例子所示:
x?=?torch.randn(3,?requires_grad=True)x1?=?(x**2)
print(x.requires_grad)
print(x1.requires_grad)
with?torch.no_grad():
????x2?=?(x**2)
????print(x1.requires_grad)
????print(x2.requires_grad)
輸出為:
TrueTrue
True
False
注意到只是在torch.no_grad()上下文管理器范圍內(nèi)計(jì)算的中間變量的屬性requires_grad才會(huì)被轉(zhuǎn)變?yōu)镕alse,在該管理器外面計(jì)算的并不會(huì)變化。
不過(guò)和單純手動(dòng)設(shè)置requires_grad=False不同的是,在設(shè)置了torch.no_grad()之前的層是不能回傳梯度的,延續(xù)之前的例子如:
x?=?torch.randn(2,?2)x.requires_grad?=?True
lin0?=?nn.Linear(2,?2)
lin1?=?nn.Linear(2,?2)
lin2?=?nn.Linear(2,?2)
lin3?=?nn.Linear(2,?2)
x1?=?lin0(x)
with?torch.no_grad():
????x2?=?lin1(x1)
x3?=?lin2(x2)
x4?=?lin3(x3)
x4.sum().backward()
print(lin0.weight.grad)
print(lin1.weight.grad)
print(lin2.weight.grad)
print(lin3.weight.grad)
輸出為:
NoneNone
tensor([[-0.0926,?-0.0945],
????????[-0.2793,?-0.2851]])
tensor([[-0.5216,??0.8088],
????????[-0.5216,??0.8088]])
此處如果我們打印lin1.weight.requires_grad我們會(huì)發(fā)現(xiàn)其為T(mén)rue,但是其中間變量x2.requires_grad=False。
一般來(lái)說(shuō)在實(shí)踐中,我們的torch.no_grad()通常會(huì)在測(cè)試模型的時(shí)候使用,而不會(huì)選擇在選擇性訓(xùn)練某些模塊時(shí)使用[1],例子如:
model.train()#?here?train?the?model,?just?skip?the?codes
model.eval()?#?here?we?start?to?evaluate?the?model
with?torch.no_grad():
?for?each?in?eval_data:
??data,?label?=?each
??logit?=?model(data)
??...?#?here?we?just?skip?the?codes
注意
通過(guò)設(shè)置屬性requires_grad=False的方法(包括torch.no_grad())很多時(shí)候可以避免保存中間計(jì)算的buffer,從而減少對(duì)內(nèi)存的需求,但是這個(gè)也是視情況而定的,比如如[2]的所示
graph LR;input-->A_net;
A_net-->B_net;
B_net-->C_net;
如果我們不需要A_net的梯度,我們?cè)O(shè)置所有A_net的requires_grad=False,因?yàn)楹罄m(xù)的B_net和C_net的梯度流并不依賴(lài)于A_net,因此不計(jì)算A_net的梯度流意味著不需要保存這個(gè)中間計(jì)算結(jié)果,因此減少了內(nèi)存。
但是如果我們不需要的是B_net的梯度,而需要A_net和C_net的梯度,那么問(wèn)題就不一樣了,因?yàn)锳_net梯度依賴(lài)于B_net的梯度,就算不計(jì)算B_net的梯度,也需要保存回傳過(guò)程中B_net中間計(jì)算的結(jié)果,因此內(nèi)存并不會(huì)被減少。
但是通過(guò)tensor.detach()的方法并不會(huì)減少內(nèi)存使用,這一點(diǎn)需要注意。
設(shè)置優(yōu)化器的更新列表
這個(gè)方法更為直接,即便某個(gè)模塊進(jìn)行了梯度計(jì)算,我只需要在優(yōu)化器中指定不更新該模塊的參數(shù),那么這個(gè)模塊就和沒(méi)有計(jì)算梯度有著同樣的效果了。如以下代碼所示:
class?model(nn.Module):????def?__init__(self):
????????super().__init__()
????????self.model_1?=?nn.linear(10,10)
????????self.model_2?=?nn.linear(10,20)
????????self.fc?=?nn.linear(20,2)
????????self.relu?=?nn.ReLU()
???????
????def?foward(inputv):
????????h?=?self.model_1(inputv)
????????h?=?self.relu(h)
????????h?=?self.model_2(inputv)
????????h?=?self.relu(h)
????????return?self.fc(h)
在設(shè)置優(yōu)化器時(shí),我們只需要更新fc層和model_2層,那么則是:
curr_model?=?model()opt_list?=?list(curr_model.fc.parameters())+list(curr_model.model_2.parameters())
optimizer?=?torch.optim.SGD(opt_list,?lr=1e-4)
當(dāng)然你也可以通過(guò)以下的方法去設(shè)置每一個(gè)層的學(xué)習(xí)率來(lái)避免不需要更新的層的更新[3]:
optim.SGD([????????????????{'params':?model.model_1.parameters()},
????????????????{'params':?model.mode_2.parameters(),?'lr':?0},
?????????{'params':?model.fc.parameters(),?'lr':?0}
????????????],?lr=1e-2,?momentum=0.9)
這種方法不需要更改模型本身結(jié)構(gòu),也不需要添加模型的額外節(jié)點(diǎn),但是需要保存梯度的中間變量,并且將會(huì)計(jì)算不需要計(jì)算的模塊的梯度(即便最后優(yōu)化的時(shí)候不考慮更新),這樣浪費(fèi)了內(nèi)存和計(jì)算時(shí)間。
Reference
[1]. https://blog.csdn.net/LoseInVain/article/details/82916163
[2]. https://discuss.pytorch.org/t/requires-grad-false-does-not-save-memory/21936
[3]. https://pytorch.org/docs/stable/optim.html#module-torch.optim
總結(jié)
以上是生活随笔為你收集整理的更新fielddata为true_在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新...的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 德州计算机速成班培训,德州办公软件培训速
- 下一篇: 点击修改表格背景色