把显存用在刀刃上!17 种 pytorch 节约显存技巧
引導(dǎo)
- 1. 顯存都用在哪兒了?
- 2. 技巧 1:使用就地操作
- 3. 技巧 2:避免中間變量
- 4. 技巧 3:優(yōu)化網(wǎng)絡(luò)模型
- 5. 技巧 4:減小 BATCH_SIZE
- 6. 技巧 5:拆分 BATCH
- 7. 技巧 6:降低 PATCH_SIZE
- 8. 技巧 7:優(yōu)化損失求和
- 9. 技巧 8:調(diào)整訓(xùn)練精度
- 10. 技巧 9:分割訓(xùn)練過程
- 11. 技巧10:清理內(nèi)存垃圾
- 12. 技巧11:使用梯度累積
- 13. 技巧12:清除不必要梯度
- 14. 技巧13:周期清理顯存
- 15. 技巧14:多使用下采樣
- 16. 技巧15:刪除無用變量
- 17. 技巧16:改變優(yōu)化器
- 18. 終極技巧
1. 顯存都用在哪兒了?
一般在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時,顯存主要被網(wǎng)絡(luò)模型和中間變量占用。
- 網(wǎng)絡(luò)模型中的卷積層,全連接層和標(biāo)準(zhǔn)化層等的參數(shù)占用顯存,而諸如激活層和池化層等本質(zhì)上是不占用顯存的。
- 中間變量包括特征圖和優(yōu)化器等,是消耗顯存最多的部分。
- 其實(shí) pytorch 本身也占用一些顯存的,但占用不多,以下方法大致按照推薦的優(yōu)先順序。
2. 技巧 1:使用就地操作
就地操作 (inplace) 字面理解就是在原地對變量進(jìn)行操作,對應(yīng)到 pytorch 中就是在原內(nèi)存上對變量進(jìn)行操作而不申請新的內(nèi)存空間,從而減少對內(nèi)存的使用。具體來說就地操作包括三個方面的實(shí)現(xiàn)途徑:
- 使用將 inplace 屬性定義為 True 的激活函數(shù),如 nn.ReLU(inplace=True)
- 使用 pytorch 帶有就地操作的方法,一般是方法名后跟一個下劃線 “_”,如 tensor.add_(),tensor.scatter_(),F.relu_()
- 使用就地操作的運(yùn)算符,如 y += x,y *= x
3. 技巧 2:避免中間變量
在自定義網(wǎng)絡(luò)結(jié)構(gòu)的成員方法 forward 函數(shù)里,避免使用不必要的中間變量,盡量在之前已申請的內(nèi)存里進(jìn)行操作,比如下面的代碼就使用太多中間變量,占用大量不必要的顯存:
def forward(self, x):x0 = self.conv0(x) # 輸入層x1 = F.relu_(self.conv1(x0) + x0)x2 = F.relu_(self.conv2(x1) + x1)x3 = F.relu_(self.conv3(x2) + x2)x4 = F.relu_(self.conv4(x3) + x3)x5 = F.relu_(self.conv5(x4) + x4)x6 = self.conv(x5) # 輸出層return x6為了減少顯存占用,可以將上述 forward 函數(shù)修改如下:
def forward(self, x):x = self.conv0(x) # 輸入層x = F.relu_(self.conv1(x) + x)x = F.relu_(self.conv2(x) + x)x = F.relu_(self.conv3(x) + x)x = F.relu_(self.conv4(x) + x)x = F.relu_(self.conv5(x) + x)x = self.conv(x) # 輸出層return x上述兩段代碼實(shí)現(xiàn)的功能是一樣的,但對顯存的占用卻相去甚遠(yuǎn),后者能節(jié)省前者占用顯存的接近 90% 之多。
4. 技巧 3:優(yōu)化網(wǎng)絡(luò)模型
網(wǎng)絡(luò)模型對顯存的占用主要指的就是卷積層,全連接層和標(biāo)準(zhǔn)化層等的參數(shù),具體優(yōu)化途徑包括但不限于:
- 減少卷積核數(shù)量 (=減少輸出特征圖通道數(shù))
- 不使用全連接層
- 全局池化 nn.AdaptiveAvgPool2d() 代替全連接層 nn.Linear()
- 不使用標(biāo)準(zhǔn)化層
- 跳躍連接跨度不要太大太多 (避免產(chǎn)生大量中間變量)
5. 技巧 4:減小 BATCH_SIZE
- 在訓(xùn)練卷積神經(jīng)網(wǎng)絡(luò)時,epoch 代表的是數(shù)據(jù)整體進(jìn)行訓(xùn)練的次數(shù),batch 代表將一個 epoch 拆分為 batch_size 批來參與訓(xùn)練。
- 減小 batch_size 是一個減小顯存占用的慣用技巧,在訓(xùn)練時顯存不夠一般優(yōu)先減小 batch_size ,但 batch_size 不能無限變小,太大會導(dǎo)致網(wǎng)絡(luò)不穩(wěn)定,太小會導(dǎo)致網(wǎng)絡(luò)不收斂。
6. 技巧 5:拆分 BATCH
拆分 batch 跟技巧 4 中減小 batch_size 本質(zhì)是不一樣的, 這種拆分 batch 的操作可以理解為將兩次訓(xùn)練的損失相加再反向傳播,但減小 batch_size 的操作是訓(xùn)練一次反向傳播一次。拆分 batch 操作可以理解為三個步驟,假設(shè)原來 batch 的大小 batch_size=64:
- 將 batch 拆分為兩個 batch_size=32 的小 batch
- 分別輸入網(wǎng)絡(luò)與目標(biāo)值計算損失,將得到的損失相加
- 進(jìn)行反向傳播
7. 技巧 6:降低 PATCH_SIZE
- 在卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練中,patch_size 指的是輸入神經(jīng)網(wǎng)絡(luò)的圖像大小,即(H*W)。
- 網(wǎng)絡(luò)輸入 patch 的大小對于后續(xù)特征圖的大小等影響非常大,訓(xùn)練時可能采用諸如 [64*64],[128*128] 等大小的 patch,如果顯存不足可以進(jìn)一步縮小 patch 的大小,比如 [32*32],[16*16]。
- 但這種方法存在問題,可能極大地影響網(wǎng)絡(luò)的泛化能力,在裁剪的時候一定要注意在原圖上隨機(jī)裁剪,一般不建議。
8. 技巧 7:優(yōu)化損失求和
一個 batch 訓(xùn)練結(jié)束會得到相應(yīng)的一個損失值,如果要計算一個 epoch 的損失就需要累加之前產(chǎn)生的所有 batch 損失,但之前的 batch 損失在 GPU 中占用顯存,直接累加得到的 epoch 損失也會在 GPU 中占用顯存,可以通過如下方法進(jìn)行優(yōu)化:
epoch_loss += batch_loss.detach().item() # epoch 損失上邊代碼的效果就是首先解除 batch_loss 張量的 GPU 占用,將張量中的數(shù)據(jù)取出再進(jìn)行累加。
9. 技巧 8:調(diào)整訓(xùn)練精度
- 降低訓(xùn)練精度
pytorch 中訓(xùn)練神經(jīng)網(wǎng)絡(luò)時浮點(diǎn)數(shù)默認(rèn)使用 32 位浮點(diǎn)型數(shù)據(jù),在訓(xùn)練對于精度要求不是很高的網(wǎng)絡(luò)時可以改為 16 位浮點(diǎn)型數(shù)據(jù)進(jìn)行訓(xùn)練,但要注意同時將數(shù)據(jù)和網(wǎng)絡(luò)模型都轉(zhuǎn)為 16 位浮點(diǎn)型數(shù)據(jù),否則會報錯。降低浮點(diǎn)型數(shù)據(jù)的操作實(shí)現(xiàn)過程非常簡單,但如果優(yōu)化器選擇 Adam 時可能會報錯,選擇 SGD 優(yōu)化器則不會報錯,具體操作步驟如下:
- 混合精度訓(xùn)練
混合精度訓(xùn)練指的是用 GPU 訓(xùn)練網(wǎng)絡(luò)時,相關(guān)數(shù)據(jù)在內(nèi)存中用半精度做儲存和乘法來加速計算,用全精度進(jìn)行累加避免舍入誤差,這種混合經(jīng)度訓(xùn)練的方法可以令訓(xùn)練時間減少一半左右,也可以很大程度上減小顯存占用。在 pytorch1.6 之前多使用 NVIDIA 提供的 apex 庫進(jìn)行訓(xùn)練,之后多使用 pytorch 自帶的 amp 庫,實(shí)例代碼如下:
10. 技巧 9:分割訓(xùn)練過程
- 如果訓(xùn)練的網(wǎng)絡(luò)非常深,比如 resnet101 就是一個很深的網(wǎng)絡(luò),直接訓(xùn)練深度神經(jīng)網(wǎng)絡(luò)對顯存的要求非常高,一般一次無法直接訓(xùn)練整個網(wǎng)絡(luò)。在這種情況下,可以將復(fù)雜網(wǎng)絡(luò)分割為兩個小網(wǎng)絡(luò),分別進(jìn)行訓(xùn)練。
- checkpoint 是 pytorch 中一種用時間換空間的顯存不足解決方案,這種方法本質(zhì)上減少的是參與一次訓(xùn)練網(wǎng)絡(luò)整體的參數(shù)量,如下是一個實(shí)例代碼。
- 使用 checkpoint 進(jìn)行網(wǎng)絡(luò)訓(xùn)練要求輸入屬性 requires_grad=True ,在給出的代碼中將一個網(wǎng)絡(luò)結(jié)構(gòu)拆分為 3 個子網(wǎng)絡(luò)進(jìn)行訓(xùn)練,對于沒有 nn.Sequential() 構(gòu)建神經(jīng)網(wǎng)絡(luò)的情況無非就是自定義的子網(wǎng)絡(luò)里多幾項(xiàng),或者像例子中一樣單獨(dú)構(gòu)建網(wǎng)絡(luò)塊。
- 對于由 nn.Sequential() 包含的大網(wǎng)絡(luò)塊 (小網(wǎng)絡(luò)塊時沒必要),可以使用 checkpoint_sequential 包來簡化實(shí)現(xiàn),具體實(shí)現(xiàn)過程如下:
11. 技巧10:清理內(nèi)存垃圾
- python 中定義的變量一般在使用結(jié)束時不會立即釋放資源,在訓(xùn)練循環(huán)開始時可以利用如下代碼來回收內(nèi)存垃圾。
12. 技巧11:使用梯度累積
- 由于顯存大小的限制,訓(xùn)練大型網(wǎng)絡(luò)模型時無法使用較大的 batch_size ,而一般較大的 batch_size 能令網(wǎng)絡(luò)模型更快收斂。
- 梯度累積就是將多個 batch 計算得到的損失平均后累積再進(jìn)行反向傳播,類似于技巧 5 中拆分 batch 的思想(但技巧 5 是將大 batch 拆小,訓(xùn)練的依舊是大 batch,而梯度累積訓(xùn)練的是小 batch)。
- 可以采用梯度累積的思想來模擬較大 batch_size 可以達(dá)到的效果,具體實(shí)現(xiàn)代碼如下:
13. 技巧12:清除不必要梯度
在運(yùn)行測試程序時不涉及到與梯度有關(guān)的操作,因此可以清楚不必要的梯度以節(jié)約顯存,具體包括但不限于如下操作:
- 用代碼 model.eval() 將模型置于測試狀態(tài),不啟用標(biāo)準(zhǔn)化和隨機(jī)舍棄神經(jīng)元等操作。
- 測試代碼放入上下文管理器 with torch.no_grad(): 中,不進(jìn)行圖構(gòu)建等操作。
- 在訓(xùn)練或測試每次循環(huán)開始時加梯度清零操作
14. 技巧13:周期清理顯存
- 同理也可以在訓(xùn)練每次循環(huán)開始時利用 pytorch 自帶清理顯存的代碼來釋放不用的顯存資源。
執(zhí)行這條語句釋放的顯存資源在用 Nvidia-smi 命令查看時體現(xiàn)不出,但確實(shí)是已經(jīng)釋放。其實(shí) pytorch 原則上是如果變量不再被引用會自動釋放,所以這條語句可能沒啥用,但個人覺得多少有點(diǎn)用。
15. 技巧14:多使用下采樣
下采樣從實(shí)現(xiàn)上來看類似池化,但不限于池化,其實(shí)也可以用步長大于 1 來代替池化等操作來進(jìn)行下采樣。從結(jié)果上來看就是通過下采樣得到的特征圖會縮小,特征圖縮小自然參數(shù)量減少,進(jìn)而節(jié)約顯存,可以用如下兩種方式實(shí)現(xiàn):
nn.Conv2d(32, 32, 3, 2, 1) # 步長大于 1 下采樣nn.Conv2d(32, 32, 3, 1, 1) # 卷積核接池化下采樣 nn.MaxPool2d(2, 2)16. 技巧15:刪除無用變量
del 功能是徹底刪除一個變量,要再使用必須重新創(chuàng)建,注意 del 刪除的是一個變量而不是從內(nèi)存中刪除一個數(shù)據(jù),這個數(shù)據(jù)有可能也被別的變量在引用,實(shí)現(xiàn)方法很簡單,比如:
def forward(self, x):input_ = xx = F.relu_(self.conv1(x) + input_)x = F.relu_(self.conv2(x) + input_)x = F.relu_(self.conv3(x) + input_)del input_ # 刪除變量 input_x = self.conv4(x) # 輸出層return x17. 技巧16:改變優(yōu)化器
進(jìn)行網(wǎng)絡(luò)訓(xùn)練時比較常用的優(yōu)化器是 SGD 和 Adam,拋開訓(xùn)練最后的效果來談,SGD 對于顯存的占用相比 Adam 而言是比較小的,實(shí)在沒有辦法時可以嘗試改變參數(shù)優(yōu)化算法,兩種優(yōu)化算法的調(diào)用是相似的:
import torch.optim as optim from torchvision.models import resnet18LEARNING_RATE = 1e-3 # 學(xué)習(xí)率 myNet = resnet18().cuda() # 實(shí)例化網(wǎng)絡(luò)optimizer_adam = optim.Adam(myNet.parameters(), lr=LEAENING_RATE) # adam 網(wǎng)絡(luò)參數(shù)優(yōu)化算法 optimizer_sgd = optim.SGD(myNet.parameters(), lr=LEAENING_RATE) # sgd 網(wǎng)絡(luò)參數(shù)優(yōu)化算法18. 終極技巧
購買顯存夠大的顯卡,一塊不行那就 多來幾塊。
總結(jié)
以上是生活随笔為你收集整理的把显存用在刀刃上!17 种 pytorch 节约显存技巧的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: jquery实现注册表单验证
- 下一篇: GitLab Admin Area