PyTorch-混合精度训练
簡(jiǎn)介
自動(dòng)混合精度訓(xùn)練(auto Mixed Precision,amp)是深度學(xué)習(xí)比較流行的一個(gè)訓(xùn)練技巧,它可以大幅度降低訓(xùn)練的成本并提高訓(xùn)練的速度,因此在競(jìng)賽中受到了較多的關(guān)注。此前,比較流行的混合精度訓(xùn)練工具是由NVIDIA開(kāi)發(fā)的A PyTorch Extension(Apex),它能夠以非常簡(jiǎn)單的API支持自動(dòng)混合精度訓(xùn)練,不過(guò),PyTorch從1.6版本開(kāi)始已經(jīng)內(nèi)置了amp模塊,本文簡(jiǎn)單介紹其使用。
自動(dòng)混合精度(AMP)
首先來(lái)聊聊自動(dòng)混合精度的由來(lái)。下圖是常見(jiàn)的浮點(diǎn)數(shù)表示形式,它表示單精度浮點(diǎn)數(shù),在編程語(yǔ)言中的體現(xiàn)是float型,顯然從圖中不難看出它需要4個(gè)byte也就是32bit來(lái)進(jìn)行存儲(chǔ)。深度學(xué)習(xí)的模型數(shù)據(jù)均采用float32進(jìn)行表示,這就帶來(lái)了兩個(gè)問(wèn)題:模型size大,對(duì)顯存要求高;32位計(jì)算慢,導(dǎo)致模型訓(xùn)練和推理速度慢。
那么半精度是什么呢,顧名思義,它只用16位即2byte來(lái)進(jìn)行表示,較小的存儲(chǔ)占用以及較快的運(yùn)算速度可以緩解上面32位浮點(diǎn)數(shù)的兩個(gè)主要問(wèn)題,因此半精度會(huì)帶來(lái)下面的一些優(yōu)勢(shì):
- 顯存占用更少,模型只有32位的一半存儲(chǔ)占用,這也可以使用更大的batch size以適應(yīng)一些對(duì)大批尺寸有需求的結(jié)構(gòu),如Batch Normalization;
- 計(jì)算速度快,float16的計(jì)算吞吐量可以達(dá)到float32的2-8倍左右,且隨著NVIDIA張量核心的普及,使用半精度計(jì)算已經(jīng)比較成熟,它會(huì)是未來(lái)深度學(xué)習(xí)計(jì)算的一個(gè)重要趨勢(shì)。
那么,半精度有沒(méi)有什么問(wèn)題呢?其實(shí)也是有著很致命的問(wèn)題的,主要是移除錯(cuò)誤和舍入誤差兩個(gè)方面,具體可以參考這篇文章,作者解析的很好,我這里就簡(jiǎn)單復(fù)述一下。
溢出錯(cuò)誤
FP16的數(shù)值表示范圍比FP32的表示范圍小很多,因此在計(jì)算過(guò)程中很容易出現(xiàn)上溢出(overflow)和下溢出(underflow)問(wèn)題,溢出后會(huì)出現(xiàn)梯度nan問(wèn)題,導(dǎo)致模型無(wú)法正確更新,嚴(yán)重影響網(wǎng)絡(luò)的收斂。而且,深度模型訓(xùn)練,由于激活函數(shù)的梯度往往比權(quán)重的梯度要小,更容易出現(xiàn)的是下溢出問(wèn)題。
舍入誤差
舍入誤差(Rounding Error)指的是當(dāng)梯度過(guò)小,小于當(dāng)前區(qū)間內(nèi)的最小間隔時(shí),該次梯度更新可能會(huì)失敗。上面說(shuō)的知乎文章的作者用來(lái)一張很形象的圖進(jìn)行解釋,具體如下,意思是說(shuō)在2?32^{-3}2?3到2?22^{-2}2?2之間,2?32^{-3}2?3每次變大都會(huì)至少加上2?132^{-13}2?13,顯然,梯度還在這個(gè)間隔內(nèi),因此更新是失敗的。
那么這兩個(gè)問(wèn)題是如何解決的呢,思路來(lái)自于NVIDIA和百度合作的論文,我這里簡(jiǎn)述一下方法:混合精度訓(xùn)練和損失縮放。前者的思路是在內(nèi)存中使用FP16做儲(chǔ)存和乘法運(yùn)算以加速計(jì)算,用FP32做累加運(yùn)算以避免舍入誤差,這樣就緩解了舍入誤差的問(wèn)題;后者則是針對(duì)梯度值太小從而下溢出的問(wèn)題,它的思想是:反向傳播前,將損失變化手動(dòng)增大2k2^k2k倍,因此反向傳播時(shí)得到的中間變量(激活函數(shù)梯度)則不會(huì)溢出;反向傳播后,將權(quán)重梯度縮小2k2^k2k倍,恢復(fù)正常值。
研究人員通過(guò)引入FP32進(jìn)行混合精度訓(xùn)練以及通過(guò)損失縮放來(lái)解決FP16的不足,從而實(shí)現(xiàn)了一套混合精度訓(xùn)練的范式,NVIDIA以此為基礎(chǔ)設(shè)計(jì)了Apex包,不過(guò)Apex的使用本文就不涉及了,下一節(jié)主要關(guān)注如何使用torch.cuda.amp實(shí)現(xiàn)自動(dòng)混合精度訓(xùn)練,不過(guò)這里還需要補(bǔ)充的一點(diǎn)就是目前混合精度訓(xùn)練支持的N卡只有包含Tensor Core的卡,如2080Ti、Titan、Tesla等。
PyTorch自動(dòng)混合精度
PyTorch對(duì)混合精度的支持始于1.6版本,位于torch.cuda.amp模塊下,主要是torch.cuda.amp.autocast和torch.cuda.amp.GradScale兩個(gè)模塊,autocast針對(duì)選定的代碼塊自動(dòng)選取適合的計(jì)算精度,以便在保持模型準(zhǔn)確率的情況下最大化改善訓(xùn)練效率;GradScaler通過(guò)梯度縮放,以最大程度避免使用FP16進(jìn)行運(yùn)算時(shí)的梯度下溢。官方給的使用這兩個(gè)模塊進(jìn)行自動(dòng)精度訓(xùn)練的示例代碼鏈接給出,我對(duì)其示例解析如下,這就是一般的訓(xùn)練框架。
# 以默認(rèn)精度創(chuàng)建模型和優(yōu)化器
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)# 創(chuàng)建梯度縮放器
scaler = GradScaler()for epoch in epochs:for input, target in data:optimizer.zero_grad()# 通過(guò)自動(dòng)類(lèi)型轉(zhuǎn)換進(jìn)行前向傳播with autocast():output = model(input)loss = loss_fn(output, target)# 縮放大損失,反向傳播不建議放到autocast下,它默認(rèn)和前向采用相同的計(jì)算精度scaler.scale(loss).backward()# 先反縮放梯度,若反縮后梯度不是inf或者nan,則用于權(quán)重更新scaler.step(optimizer)# 更新縮放器scaler.update()
下面我以簡(jiǎn)單的MNIST任務(wù)做測(cè)試,使用的顯卡為RTX 3090,代碼如下。該代碼段中只包含核心的訓(xùn)練模塊,模型的定義和數(shù)據(jù)集的加載熟悉PyTorch的應(yīng)該不難自行補(bǔ)充。
model = Model()
model = model.cuda()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())n_epochs = 30
start = time.time()
for epoch in range(n_epochs):total_loss, correct, total = 0.0, 0, 0model.train()for step, data in enumerate(data_loader_train):x_train, y_train = datax_train, y_train = x_train.cuda(), y_train.cuda()outputs = model(x_train)_, pred = torch.max(outputs, 1)loss = loss_fn(outputs, y_train)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item()total += len(y_train)correct += torch.sum(pred == y_train).item()print("epoch {} loss {} acc {}".format(epoch, total_loss, correct / total))
我這里采用的是一個(gè)很小的模型,又是一個(gè)很簡(jiǎn)單的任務(wù),因此模型都是很快收斂,因此精度上沒(méi)有什么明顯的區(qū)別,不過(guò)如果是訓(xùn)練大型模型的話,有人已經(jīng)用實(shí)驗(yàn)證明,內(nèi)置amp和apex庫(kù)都會(huì)有精度下降,不過(guò)amp效果更好一些,下降較少。上面的loss變化圖也是非常類(lèi)似的。
再來(lái)看存儲(chǔ)方面,顯存縮減在這個(gè)任務(wù)中的表現(xiàn)不是特別明顯,因?yàn)檫@個(gè)任務(wù)的參數(shù)量不多,前后向過(guò)程中的FP16存儲(chǔ)節(jié)省不明顯,而因?yàn)橐肓艘恍┛截愔?lèi)的,反而使得顯存略有上升,實(shí)際的任務(wù)中,這種開(kāi)銷(xiāo)肯定遠(yuǎn)小于FP32的開(kāi)銷(xiāo)的。
最后,不妨看一下使用混合精度最關(guān)心的速度問(wèn)題,實(shí)際上混合精度確實(shí)會(huì)帶來(lái)一些速度上的優(yōu)勢(shì),一些官方的大模型如BERT等訓(xùn)練速度提高了2-3倍,這對(duì)于工業(yè)界的需求來(lái)說(shuō),啟發(fā)還是比較多的。
總結(jié)
混合精度計(jì)算是未來(lái)深度學(xué)習(xí)發(fā)展的重要方向,很受工業(yè)界的關(guān)注,PyTorch從1.6版本開(kāi)始默認(rèn)支持amp,雖然現(xiàn)在還不是特別完善,但以后一定會(huì)越來(lái)越好,因此熟悉自動(dòng)混合精度的用法還是有必要的。
超強(qiáng)干貨來(lái)襲 云風(fēng)專訪:近40年碼齡,通宵達(dá)旦的技術(shù)人生
總結(jié)
以上是生活随笔為你收集整理的PyTorch-混合精度训练的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 五步法颈椎病自我按摩图解
- 下一篇: 线性回归,logistic回归和一般回归