一训练就显存爆炸?Facebook 推出 8 比特优化器,两行代码拯救你的显存!
文 | jxyxiangyu
編 | 小軼
“小夕,小夕!又出來了個(gè) SOTA 模型!趕緊 follow !”
小夕看了看新模型的參數(shù)量, 然后看了看實(shí)驗(yàn)室服務(wù)器的幾張小破卡。
小夕,陷入了沉默。
自從人們發(fā)現(xiàn)越大的模型性能越好后,神經(jīng)網(wǎng)絡(luò)模型的參數(shù)量就在越來越大的道路上一去不復(fù)返了。從XX-large到GPT3,再到5300億參數(shù)的Megatron Turing-NLG,深度學(xué)習(xí)越來越像是只有財(cái)大氣粗的大公司才能玩得起的玩具。如果,我們想要在實(shí)驗(yàn)室“簡(jiǎn)陋”的環(huán)境下,嘗試更大的模型,有什么行之有效的方法呢?
最近,Facebook 推出了支持 pytorch 的 8 位優(yōu)化器,在減小內(nèi)存占用的同時(shí),竟然還能保持和32位優(yōu)化器相當(dāng)?shù)臏?zhǔn)確性。不得不說 facebook yyds。那么,下面就讓我們一起來看看具體是怎么做的吧。
論文題目:
8-BIT OPTIMIZERS VIA BLOCK-WISE QUANTIZATION
論文鏈接:
https://arxiv-download.xixiaoyao.cn/pdf/2110.02861.pdf
開源鏈接:
https://github.com/facebookresearch/bitsandbytes
量化
在介紹論文作者的解決方法之前,先補(bǔ)充一點(diǎn)關(guān)于量化的基本概念。通常意義上來說,量化是指將信號(hào)的連續(xù)取值近似為有限多個(gè)離散值的過程。具體到計(jì)算機(jī)系統(tǒng),指的是將浮點(diǎn)數(shù)值映射到低bit數(shù)值的操作[1]。
一般來說,我們可以通過以下手段應(yīng)用量化
量化模型參數(shù)來壓縮模型;
量化模型某些層的激活值來減少內(nèi)存占用*;
(注:參數(shù)和梯度也會(huì)占用一定的內(nèi)存空間,但相對(duì)于激活值而言,占用比例不大,一般來說,量化參數(shù)和梯度帶來的內(nèi)存收益沒有量化激活值的大)
▲量化示意圖上圖是 Song Han 在 ICLR'2016 上提出的量化方法。將模型參數(shù)分別聚類到幾個(gè)質(zhì)心,并將參數(shù)量化到對(duì)應(yīng)的質(zhì)心,在更新參數(shù)時(shí),是將同一質(zhì)心對(duì)應(yīng)的梯度累加用于更新該質(zhì)心對(duì)應(yīng)的參數(shù)
可以看到,量化通過將參數(shù)(浮點(diǎn)值)映射到二值、三值或線性量化到一個(gè)區(qū)間(一般是低比特?cái)?shù)值)的方式,減小了模型大小,在某些硬件上面,低比特?cái)?shù)值運(yùn)算速度高于浮點(diǎn)數(shù)值,一定程度上可以加速模型的訓(xùn)練和預(yù)測(cè);除此之外,模型在訓(xùn)練和預(yù)測(cè)的時(shí)候,模型參數(shù)本身只占用了內(nèi)存的一小部分,大部分存儲(chǔ)了模型的激活值,如果將量化應(yīng)用到激活值上,一定程度也減小了內(nèi)存占用,這樣,我們就可以嘗試更大的模型和設(shè)置更大的mini-batch了。
當(dāng)然,量化這么好,也不是沒有缺點(diǎn)的,量化后的模型或多或少會(huì)引入精度損失;并且目前學(xué)術(shù)界多采用 pytorch 框架,好巧不巧的是 pytorch 框架對(duì)量化的支持沒有 tensorflow 好,這總不能為了體驗(yàn)大模型的快感再轉(zhuǎn)到 tensorflow 上面去吧,想想 tensorflow 混亂的 api 就頭疼(╯°Д°)╯ ┻━┻
最近,Facebook 推出了支持 pytorch 的 8 位優(yōu)化器,在減小內(nèi)存占用的同時(shí),竟然還能保持和32位優(yōu)化器相當(dāng)?shù)臏?zhǔn)確性。
狀態(tài)優(yōu)化器
再簡(jiǎn)單介紹下帶有狀態(tài)的優(yōu)化器(stateful optimizer)。和普通的隨機(jī)梯度下降(SGD)相比,為了加速優(yōu)化而提出的帶有梯度統(tǒng)計(jì)信息的優(yōu)化器,就是狀態(tài)優(yōu)化器。常見的例如帶動(dòng)量的 SGD 和 Adam 。計(jì)算公式如下:其中, 和 是平滑因子, 是非常小的常量, 是學(xué)習(xí)率。
作者認(rèn)為,狀態(tài)優(yōu)化器會(huì)維護(hù)歷史梯度數(shù)據(jù),一定程度上占用了內(nèi)存。通過量化這些梯度,可以有效地降低內(nèi)存占用。
非線性量化
前述已經(jīng)介紹了量化就是將信號(hào)的連續(xù)取值近似為有限多個(gè)離散值的過程,在降低模型參數(shù)量的同時(shí),也會(huì)帶來一定的精度損失,為減小精度損失,多采用非線性的量化方式,大致可以歸納為三個(gè)步驟:
對(duì)于輸入張量,計(jì)算歸一化因子;
將張量通過歸一化后,找到在量化空間中距離最近的值;
將量化后的張量的每個(gè)元素的索引存儲(chǔ)下來
那么,我們就可以遍歷存儲(chǔ)的索引并通過下式得到反量化張量:,其中,是反量化映射
為了使不同元素值量級(jí)一致,一般會(huì)將張量歸一化到的區(qū)間范圍,這時(shí),取的是輸入張量中絕對(duì)值的最大值,即,然后通過二分查找的方式找到量化空間中距離該值最近的量化值
動(dòng)態(tài)樹量化
上一節(jié)看到,非線性量化在歸一化時(shí)會(huì)嚴(yán)重依賴輸入張量中的最值,像某些特別大或特別小的異常值,對(duì)量化會(huì)產(chǎn)生較大的精度影響。動(dòng)態(tài)樹量化(dynamic tree quantization)就是一種以較低的量化精度損失處理這種情況的方法。
與浮點(diǎn)數(shù)的存儲(chǔ)方式類似,動(dòng)態(tài)樹量化以這類方式解釋存儲(chǔ)在內(nèi)存中的數(shù)值,以此實(shí)現(xiàn)量化,具體由四部分組成:
首位是符號(hào)位
符號(hào)位后連續(xù)的0的數(shù)量表示指數(shù)大小
再之后的第一位是指示位,如果指示位為1表示后續(xù)剩余的位為線性量化區(qū)域
線性量化區(qū)域
其中,指示位是可以動(dòng)態(tài)移動(dòng)的。通過移動(dòng)指示位,可以表示指數(shù)為或者精度為的數(shù)值,表示范圍為
8位優(yōu)化器
有了前面的知識(shí)鋪墊,下面就可以詳細(xì)地說明作者提出的8位優(yōu)化器了。該8位優(yōu)化器由三部分構(gòu)成:
逐塊量化(block-wise quantization)
動(dòng)態(tài)量化(dynamic quantization)
穩(wěn)定的詞嵌入層(stable embedding layer)
應(yīng)用上述組件,將8位優(yōu)化器的狀態(tài)反量化為32位并更新狀態(tài)和參數(shù),然后將這些狀態(tài)量化回8位進(jìn)行存儲(chǔ)。由于是在寄存器中進(jìn)行8位和32位的轉(zhuǎn)換,一定程度上可以減小內(nèi)存占用并加速訓(xùn)練。
逐塊量化
常見的量化需要將原始的張量在張量級(jí)別歸一化,這樣可能會(huì)引入核之間的多次信息通信和同步,造成額外的時(shí)間開銷,而逐塊量化則是將張量分成多個(gè)小塊并在塊級(jí)別歸一化,減小了核之間的通信開銷,除此之外,還可以將張量元素中的異常值的影響限制在單個(gè)塊中。假設(shè)為有個(gè)元素的張量,分成每個(gè)大小為的塊,那么,可以分成個(gè)塊,對(duì)每個(gè)塊分別做歸一化,歸一化因子為,每個(gè)塊分別通過下式進(jìn)行量化操作:其中,為塊索引,為塊中元素的索引
動(dòng)態(tài)量化
8位優(yōu)化器的動(dòng)態(tài)量化部分是對(duì)前面提到的動(dòng)態(tài)樹量化的擴(kuò)展,對(duì)于像的第二個(gè)狀態(tài)這種嚴(yán)格為正的數(shù)值,符號(hào)位就顯得有些多余,而在語言模型的訓(xùn)練過程中,作者發(fā)現(xiàn)的變化范圍在3~5個(gè)數(shù)量級(jí),小于動(dòng)態(tài)樹量化的7個(gè)數(shù)量級(jí),因此,可以用固定的位將只會(huì)用到的位劃分開,進(jìn)一步減小內(nèi)存占用。對(duì)于其他帶符號(hào)的狀態(tài)張量,則繼續(xù)使用動(dòng)態(tài)樹量化。
穩(wěn)定的詞嵌入層
為了確保nlp任務(wù)中模型的穩(wěn)定訓(xùn)練,作者還添加了穩(wěn)定的詞嵌入層。作者使用Xavier uniform對(duì)詞嵌入層進(jìn)行初始化,并且在與位置向量合并前進(jìn)行層歸一化操作,這樣可以使參數(shù)在初始化和訓(xùn)練期間保持1左右的方差。詞嵌入層的優(yōu)化器狀態(tài)用32位存儲(chǔ),權(quán)重和梯度用16位存儲(chǔ)。
8位優(yōu)化器 vs 32位優(yōu)化器
作者在多個(gè)任務(wù)(包括機(jī)器翻譯、大規(guī)模語言模型的預(yù)訓(xùn)練以及微調(diào)、圖像分類和圖像預(yù)訓(xùn)練以及微調(diào))上比較了8位優(yōu)化器和32位優(yōu)化器的性能,比較的優(yōu)化器包括、或,實(shí)驗(yàn)中,除了將32位優(yōu)化器替換為8位優(yōu)化器外,沒有改動(dòng)超參和權(quán)重、梯度以及激活值的精度。
除了GLUE任務(wù)之外,其余的NLP任務(wù)均使用了作者提出的穩(wěn)定詞嵌入層。為確保實(shí)驗(yàn)結(jié)果的可信度,還在不同隨機(jī)數(shù)種子下多次實(shí)驗(yàn),選取了實(shí)驗(yàn)結(jié)果的中位數(shù)作為最終性能。實(shí)驗(yàn)結(jié)果如下所示:
可以看到,8位優(yōu)化器在多個(gè)任務(wù)上均達(dá)到甚至是超過了32位優(yōu)化器的性能,與此同時(shí),還能大幅減小內(nèi)存開銷并加速訓(xùn)練。此外,作者還列出了在同等顯存大小的條件下,使用8位優(yōu)化器和32位優(yōu)化器可以支持訓(xùn)練的模型??梢哉f,非常貼心了ヾ(●゜ⅴ゜)ノ
消融研究
作者基于語料庫訓(xùn)練了多個(gè)模型,用于研究8位優(yōu)化器中各個(gè)組件的影響。實(shí)驗(yàn)結(jié)果如下:其中,32位優(yōu)化器(baseline)采用的是線性量化。
為測(cè)試優(yōu)化器的穩(wěn)定性,對(duì)于小規(guī)模的模型,作者分別訓(xùn)練了不同的超參數(shù)下的模型,超參數(shù)為 {1e-8, 1e-7, 1e-6}, {0.90, 0.87, 0.93}, {0.999, 0.99, 0.98}以及學(xué)習(xí)率方面的改動(dòng),而對(duì)于超過1B的大規(guī)模模型,則是在相同超參下采用不同的隨機(jī)數(shù)種子多次運(yùn)行。所有的結(jié)果均是選擇可以成功訓(xùn)練完(沒有因梯度爆炸或彌散而無法訓(xùn)練)的模型性能的中值。
可以看出,逐塊量化、動(dòng)態(tài)量化和穩(wěn)定的詞嵌入層對(duì)結(jié)果都有正向影響。
此外,作者還對(duì)比了32位優(yōu)化器和8位優(yōu)化器對(duì)超參的敏感程度,比較了32位和8位優(yōu)化器在、、和的變化下的走勢(shì)
可以看到,8位優(yōu)化器和32位相比,困惑度走勢(shì)基本一致,表明對(duì)超參不敏感,在將32位優(yōu)化器替換為8位優(yōu)化器后,超參不需要進(jìn)一步的調(diào)整
局限性
從實(shí)驗(yàn)結(jié)果可以看出,8位優(yōu)化器完全可以作為32位優(yōu)化器的替代品。當(dāng)然,8位優(yōu)化器也存在一些局限性:
8位優(yōu)化器需要穩(wěn)定的詞嵌入層來達(dá)到32位優(yōu)化器的性能;
8位優(yōu)化器減小內(nèi)存的大小與模型參數(shù)量成正比,對(duì)于像cnn這種激活值比參數(shù)占內(nèi)存多得多的模型,8位優(yōu)化器并沒有特別明顯的內(nèi)存減小,反而更適合transformer這種架構(gòu)的大規(guī)模模型
總結(jié)
不得不說,Facebook的8位優(yōu)化器簡(jiǎn)直是我等“窮困”煉丹黨的福音。現(xiàn)在,8位優(yōu)化器已經(jīng)開源,開源地址已經(jīng)在文章開頭提到。目前,8位優(yōu)化器已經(jīng)支持Adam, AdamW, RMSProp, LARS, LAMB優(yōu)化器。使用時(shí),需要安裝并導(dǎo)入包bitsandbytes-cudaXXX,其中,XXX是本地環(huán)境的cuda工具包版本號(hào),注釋掉原有的優(yōu)化器,調(diào)用8位優(yōu)化器就可以了。
import?bitsandbytes?as?bnb#?adam?=?torch.optim.Adam(model.parameters(),?lr=0.001,?betas=(0.9,?0.995))?#?comment?out?old?optimizer adam?=?bnb.optim.Adam8bit(model.parameters(),?lr=0.001,?betas=(0.9,?0.995))?#?add?bnb?optimizer adam?=?bnb.optim.Adam(model.parameters(),?lr=0.001,?betas=(0.9,?0.995),?optim_bits=8)?#?equivalenttorch.nn.Embedding(...)?->??bnb.nn.StableEmbedding(...)?#?recommended?for?NLP?models據(jù)官網(wǎng)描述,僅僅需要改動(dòng)兩行代碼,就可以節(jié)省75%的內(nèi)存!小伙伴們,還不想抓緊時(shí)間上車體驗(yàn)一下嘛?
▲沒時(shí)間解釋了,快上車后臺(tái)回復(fù)關(guān)鍵詞【入群】
加入賣萌屋NLP/IR/Rec與求職討論群
后臺(tái)回復(fù)關(guān)鍵詞【頂會(huì)】
獲取ACL、CIKM等各大頂會(huì)論文集!
?
[1] 商湯科技SenseTime, 模型量化了解一下?(https://zhuanlan.zhihu.com/p/132561405)
[2] Song, H. , ?H. Mao , and ?W. J. Dally . "Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding." ICLR 2016. (https://arxiv-download.xixiaoyao.cn/pdf/1510.00149.pdf)
總結(jié)
以上是生活随笔為你收集整理的一训练就显存爆炸?Facebook 推出 8 比特优化器,两行代码拯救你的显存!的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Step-by-step to Tran
- 下一篇: 互联网(IT)大厂面试技巧(面经)