硬核NeruIPS 2018最佳论文,一个神经了的常微分方程
機(jī)器之心原創(chuàng)
作者:蔣思源
這是一篇神奇的論文,以前一層一層疊加的神經(jīng)網(wǎng)絡(luò)似乎突然變得連續(xù)了,反向傳播也似乎不再需要一點(diǎn)一點(diǎn)往前傳、一層一層更新參數(shù)了。
在最近結(jié)束的 NeruIPS 2018 中,來自多倫多大學(xué)的陳天琦等研究者成為最佳論文的獲得者。他們提出了一種名為神經(jīng)常微分方程的模型,這是新一類的深度神經(jīng)網(wǎng)絡(luò)。神經(jīng)常微分方程不拘于對(duì)已有架構(gòu)的修修補(bǔ)補(bǔ),它完全從另外一個(gè)角度考慮如何以連續(xù)的方式借助神經(jīng)網(wǎng)絡(luò)對(duì)數(shù)據(jù)建模。在陳天琦的講解下,機(jī)器之心將向各位讀者介紹這一令人興奮的神經(jīng)網(wǎng)絡(luò)新家族。
在與機(jī)器之心的訪談中,陳天琦的導(dǎo)師 David Duvenaud 教授談起這位學(xué)生也是贊不絕口。Duvenaud 教授認(rèn)為陳天琦不僅是位理解能力超強(qiáng)的學(xué)生,鉆研起問題來也相當(dāng)認(rèn)真透徹。他說:「天琦很喜歡提出新想法,他有時(shí)會(huì)在我提出建議一周后再反饋:『老師你之前建議的方法不太合理。但是我研究出另外一套合理的方法,結(jié)果我也做出來了。』」Ducenaud 教授評(píng)價(jià)道,現(xiàn)如今人工智能熱度有增無減,教授能找到優(yōu)秀博士生基本如同「雞生蛋還是蛋生雞」的問題,頂尖學(xué)校的教授通常能快速地招納到博士生,「我很幸運(yùn)地能在事業(yè)起步階段就遇到陳天琦如此優(yōu)秀的學(xué)生。」
本文主要介紹神經(jīng)常微分方程背后的細(xì)想與直觀理解,很多延伸的概念并沒有詳細(xì)解釋,例如大大降低計(jì)算復(fù)雜度的連續(xù)型流模型和官方 PyTorch 代碼實(shí)現(xiàn)等。這一篇文章重點(diǎn)對(duì)比了神經(jīng)常微分方程(ODEnet)與殘差網(wǎng)絡(luò),我們不僅能通過這一部分了解如何從熟悉的 ResNet 演化到 ODEnet,同時(shí)還能還有新模型的前向傳播過程和特點(diǎn)。
其次文章比較關(guān)注 ODEnet 的反向傳播過程,即如何通過解常微分方程直接把梯度求出來。這一部分與傳統(tǒng)的反向傳播有很多不同,因此先理解反向傳播再看源碼可能是更好的選擇。值得注意的是,ODEnet 的反傳只有常數(shù)級(jí)的內(nèi)存占用成本。
ODEnet 的 PyTorch 實(shí)現(xiàn)地址:https://github.com/rtqichen/torchdiffeq
ODEnet 論文地址:https://arxiv.org/abs/1806.07366
如下展示了文章的主要結(jié)構(gòu):
常微分方程
從殘差網(wǎng)絡(luò)到微分方程
? ? 從微分方程到殘差網(wǎng)絡(luò)
? ? 網(wǎng)絡(luò)對(duì)比
神經(jīng)常微分方程
? ?反向傳播
? ? 反向傳播怎么做
連續(xù)型的歸一化流
? ? 變量代換定理
常微分方程
其實(shí)初讀這篇論文,還是有一些疑惑的,因?yàn)楹芏喔拍疃疾皇俏覀兯熘摹R虼巳绻胍私膺@個(gè)模型,那么同學(xué)們,你們首先需要回憶高數(shù)上的微分方程。有了這樣的概念后,我們就能愉快地連續(xù)化神經(jīng)網(wǎng)絡(luò)層級(jí),并構(gòu)建完整的神經(jīng)常微分方程。
常微分方程即只包含單個(gè)自變量 x、未知函數(shù) f(x) 和未知函數(shù)的導(dǎo)數(shù) f'(x) 的等式,所以說 f'(x) = 2x 也算一個(gè)常微分方程。但更常見的可以表示為 df(x)/dx = g(f(x), x),其中 g(f(x), x) 表示由 f(x) 和 x 組成的某個(gè)表達(dá)式,這個(gè)式子是擴(kuò)展一般神經(jīng)網(wǎng)絡(luò)的關(guān)鍵,我們在后面會(huì)討論這個(gè)式子怎么就連續(xù)化了神經(jīng)網(wǎng)絡(luò)層級(jí)。
一般對(duì)于常微分方程,我們希望解出未知的 f(x),例如 f'(x) = 2x 的通解為 f(x)=x^2 +C,其中 C 表示任意常數(shù)。而在工程中更常用數(shù)值解,即給定一個(gè)初值 f(x_0),我們希望解出末值 f(x_1),這樣并不需要解出完整的 f(x),只需要一步步逼近它就行了。
現(xiàn)在回過頭來討論我們熟悉的神經(jīng)網(wǎng)絡(luò),本質(zhì)上不論是全連接、循環(huán)還是卷積網(wǎng)絡(luò),它們都類似于一個(gè)非常復(fù)雜的復(fù)合函數(shù),復(fù)合的次數(shù)就等于層級(jí)的深度。例如兩層全連接網(wǎng)絡(luò)可以表示為 Y=f(f(X, θ1), θ2),因此每一個(gè)神經(jīng)網(wǎng)絡(luò)層級(jí)都類似于萬能函數(shù)逼近器。
因?yàn)檎w是復(fù)合函數(shù),所以很容易接受復(fù)合函數(shù)的求導(dǎo)方法:鏈?zhǔn)椒▌t,并將梯度從最外一層的函數(shù)一點(diǎn)點(diǎn)先向里面層級(jí)的函數(shù)傳遞,并且每傳到一層函數(shù),就可以更新該層的參數(shù) θ。現(xiàn)在問題是,我們前向傳播過后需要保留所有層的激活值,并在沿計(jì)算路徑反傳梯度時(shí)利用這些激活值。這對(duì)內(nèi)存的占用非常大,因此也就限制了深度模型的訓(xùn)練過程。
神經(jīng)常微分方程走了另一條道路,它使用神經(jīng)網(wǎng)絡(luò)參數(shù)化隱藏狀態(tài)的導(dǎo)數(shù),而不是如往常那樣直接參數(shù)化隱藏狀態(tài)。這里參數(shù)化隱藏狀態(tài)的導(dǎo)數(shù)就類似構(gòu)建了連續(xù)性的層級(jí)與參數(shù),而不再是離散的層級(jí)。因此參數(shù)也是一個(gè)連續(xù)的空間,我們不需要再分層傳播梯度與更新參數(shù)。總而言之,神經(jīng)微分方程在前向傳播過程中不儲(chǔ)存任何中間結(jié)果,因此它只需要近似常數(shù)級(jí)的內(nèi)存成本。
從殘差網(wǎng)絡(luò)到微分方程
殘差網(wǎng)絡(luò)是一類特殊的卷積網(wǎng)絡(luò),它通過殘差連接而解決了梯度反傳問題,即當(dāng)神經(jīng)網(wǎng)絡(luò)層級(jí)非常深時(shí),梯度仍然能有效傳回輸入端。下圖為原論文中殘差模塊的結(jié)構(gòu),殘差塊的輸出結(jié)合了輸入信息與內(nèi)部卷積運(yùn)算的輸出信息,這種殘差連接或恒等映射表示深層模型至少不能低于淺層網(wǎng)絡(luò)的準(zhǔn)確度。這樣的殘差模塊堆疊幾十上百個(gè)就是非常深的殘差神經(jīng)網(wǎng)絡(luò)。
如果我們將上面的殘差模塊更加形式化地表示為以下方程:
其中 h_t 是第 t 層隱藏單元的輸出值,f 為通過θ_t 參數(shù)化的神經(jīng)網(wǎng)絡(luò)。該方程式表示上圖的整個(gè)殘差模塊,如果我們其改寫為殘差的形式,即 h_t+1 - h_t = f(h_t, θ_t )。那么我們可以看到神經(jīng)網(wǎng)絡(luò) f 參數(shù)化的是隱藏層之間的殘差,f 同樣不是直接參數(shù)化隱藏層。
ResNet 假設(shè)層級(jí)的離散的,第 t 層到第 t+1 層之間是無定義的。那么如果這中間是有定義的呢?殘差項(xiàng) h_t0 - h_t1 是不是就應(yīng)該非常小,以至于接近無窮小?這里我們少考慮了分母,即殘差項(xiàng)應(yīng)該表示為 (h_t+1 - h_t )/1,分母的 1 表示兩個(gè)離散的層級(jí)之間相差 1。所以再一次考慮層級(jí)間有定義,我們會(huì)發(fā)現(xiàn)殘差項(xiàng)最終會(huì)收斂到隱藏層對(duì) t 的導(dǎo)數(shù),而神經(jīng)網(wǎng)絡(luò)實(shí)際上參數(shù)化的就是這個(gè)導(dǎo)數(shù)。
所以若我們在層級(jí)間加入更多的層,且最終趨向于添加了無窮層時(shí),神經(jīng)網(wǎng)絡(luò)就連續(xù)化了。可以說殘差網(wǎng)絡(luò)其實(shí)就是連續(xù)變換的歐拉離散化,是一個(gè)特例,我們可以將這種連續(xù)變換形式化地表示為一個(gè)常微分方程:
如果從導(dǎo)數(shù)定義的角度來看,當(dāng) t 的變化趨向于無窮小時(shí),隱藏狀態(tài)的變化 dh(t) 可以通過神經(jīng)網(wǎng)絡(luò)建模。當(dāng) t 從初始一點(diǎn)點(diǎn)變化到終止,那么 h(t) 的改變最終就代表著前向傳播結(jié)果。這樣利用神經(jīng)網(wǎng)絡(luò)參數(shù)化隱藏層的導(dǎo)數(shù),就確確實(shí)實(shí)連續(xù)化了神經(jīng)網(wǎng)絡(luò)層級(jí)。
現(xiàn)在若能得出該常微分方程的數(shù)值解,那么就相當(dāng)于完成了前向傳播。具體而言,若 h(0)=X 為輸入圖像,那么終止時(shí)刻的隱藏層輸出 h(T) 就為推斷結(jié)果。這是一個(gè)常微分方程的初值問題,可以直接通過黑箱的常微分方程求解器(ODE Solver)解出來。而這樣的求解器又能控制數(shù)值誤差,因此我們總能在計(jì)算力和模型準(zhǔn)確度之間做權(quán)衡。
形式上來說,現(xiàn)在就需要變換方程 (2) 以求出數(shù)值解,即給定初始狀態(tài) h(t_0) 和神經(jīng)網(wǎng)絡(luò)的情況下求出終止?fàn)顟B(tài) h(t_1):
如上所示,常微分方程的數(shù)值解 h(t_1) 需要求神經(jīng)網(wǎng)絡(luò) f 從 t_0 到 t_1 的積分。我們完全可以利用 ODE solver 解出這個(gè)值,這在數(shù)學(xué)物理領(lǐng)域已經(jīng)有非常成熟的解法,我們只需要將其當(dāng)作一個(gè)黑盒工具使用就行了。
從微分方程到殘差網(wǎng)絡(luò)
前面提到過殘差網(wǎng)絡(luò)是神經(jīng)常微分方程的特例,可以說殘差網(wǎng)絡(luò)是歐拉方法的離散化。兩三百年前解常微分方程的歐拉法非常直觀,即 h(t +Δt) = h(t) + Δt×f(h(t), t)。每當(dāng)隱藏層沿 t 走一小步Δt,新的隱藏層狀態(tài) h(t +Δt) 就應(yīng)該近似在已有的方向上邁一小步。如果這樣一小步一小步從 t_0 走到 t_1,那么就求出了 ODE 的數(shù)值解。
如果我們令 Δt 每次都等于 1,那么離散化的歐拉方法就等于殘差模塊的表達(dá)式 h(t+1) = h(t) + f(h(t), t)。但是歐拉法只是解常微分方程最基礎(chǔ)的方法,它每走一步都會(huì)產(chǎn)生一點(diǎn)誤差,且誤差會(huì)累積起來。近百年來,數(shù)學(xué)家構(gòu)建了很多現(xiàn)代 ODE 求解方法,它們不僅能保證收斂到真實(shí)解,同時(shí)還能控制誤差水平。
陳天琦等研究者構(gòu)建的 ODE 網(wǎng)絡(luò)就使用了一種適應(yīng)性的 ODE solver,它不像歐拉法移動(dòng)固定的步長,相反它會(huì)根據(jù)給定的誤差容忍度選擇適當(dāng)?shù)牟介L逼近真實(shí)解。如下圖所示,左邊的殘差網(wǎng)絡(luò)定義有限轉(zhuǎn)換的離散序列,它從 0 到 1 再到 5 是離散的層級(jí)數(shù),且在每一層通過激活函數(shù)做一次非線性轉(zhuǎn)換。此外,黑色的評(píng)估位置可以視為神經(jīng)元,它會(huì)對(duì)輸入做一次轉(zhuǎn)換以修正傳遞的值。而右側(cè)的 ODE 網(wǎng)絡(luò)定義了一個(gè)向量場,隱藏狀態(tài)會(huì)有一個(gè)連續(xù)的轉(zhuǎn)換,黑色的評(píng)估點(diǎn)也會(huì)根據(jù)誤差容忍度自動(dòng)調(diào)整。
網(wǎng)絡(luò)對(duì)比
在 David 的 Oral 演講中,他以兩段偽代碼展示了 ResNet 與 ODEnet 之間的差別。如下展示了 ResNet 的主要過程,其中 f 可以視為卷積層,ResNet 為整個(gè)模型架構(gòu)。在卷積層 f 中,h 為上一層輸出的特征圖,t 確定目前是第幾個(gè)卷積層。ResNet 中的循環(huán)體為殘差連接,因此該網(wǎng)絡(luò)一共 T 個(gè)殘差模塊,且最終返回第 T 層的輸出值。
????return?nnet(h,?θ_t)
def?resnet(h):
????for?t?in?[1:T]:
????????h?=?h?+?f(h,?t,?θ)
????return?h
相比常見的 ResNet,下面的偽代碼就比較新奇了。首先 f 與前面一樣定義的是神經(jīng)網(wǎng)絡(luò),不過現(xiàn)在它的參數(shù)θ是一個(gè)整體,同時(shí) t 作為獨(dú)立參數(shù)也需要饋送到神經(jīng)網(wǎng)絡(luò)中,這表明層級(jí)之間也是有定義的,它是一種連續(xù)的網(wǎng)絡(luò)。而整個(gè) ODEnet 不需要通過循環(huán)搭建離散的層級(jí),它只要通過 ODE solver 求出 t_1 時(shí)刻的 h 就行了。
????return?nnet([h,?t],?θ)
def?ODEnet(h,?θ):
????return?ODESolver(f,?h,?t_0,?t_1,?θ)
除了計(jì)算過程不一樣,陳天琦等研究者還在 MNSIT 測試了這兩種模型的效果。他們使用帶有 6 個(gè)殘差模塊的 ResNet,以及使用一個(gè) ODE Solver 代替這些殘差模塊的 ODEnet。以下展示了不同網(wǎng)絡(luò)在 MNSIT 上的效果、參數(shù)量、內(nèi)存占用量和計(jì)算復(fù)雜度。
其中單個(gè)隱藏層的 MLP 引用自 LeCun 在 1998 年的研究,其隱藏層只有 300 個(gè)神經(jīng)元,但是 ODEnet 在有相似參數(shù)量的情況下能獲得顯著更好的結(jié)果。上表中 L 表示神經(jīng)網(wǎng)絡(luò)的層級(jí)數(shù),L tilde 表示 ODE Solver 中的評(píng)估次數(shù),它可以近似代表 ODEnet 的「層級(jí)深度」。值得注意的是,ODEnet 只有常數(shù)級(jí)的內(nèi)存占用,這表示不論層級(jí)的深度如何增加,它的內(nèi)存占用基本不會(huì)有太大的變化。
神經(jīng)常微分方程
在與 ResNet 的類比中,我們基本上已經(jīng)了解了 ODEnet 的前向傳播過程。首先輸入數(shù)據(jù) Z(t_0),我們可以通過一個(gè)連續(xù)的轉(zhuǎn)換函數(shù)(神經(jīng)網(wǎng)絡(luò))對(duì)輸入進(jìn)行非線性變換,從而得到 f。隨后 ODESolver 對(duì) f 進(jìn)行積分,再加上初值就可以得到最后的推斷結(jié)果。如下所示,殘差網(wǎng)絡(luò)只不過是用一個(gè)離散的殘差連接代替 ODE Solver。
在前向傳播中,ODEnet 還有幾個(gè)非常重要的性質(zhì),即模型的層級(jí)數(shù)與模型的誤差控制。首先因?yàn)槭沁B續(xù)模型,其并沒有明確的層級(jí)數(shù),因此我們只能使用相似的度量確定模型的「深度」,作者在這篇論文中采用 ODE Solver 評(píng)估的次數(shù)作為深度。
其次,深度與誤差控制有著直接的聯(lián)系,ODEnet 通過控制誤差容忍度能確定模型的深度。因?yàn)?ODE Solver 能確保在誤差容忍度之內(nèi)逼近常微分方程的真實(shí)解,改變誤差容忍度就能改變神經(jīng)網(wǎng)絡(luò)的行為。一般而言,降低 ODE Solver 的誤差容忍度將增加函數(shù)的評(píng)估的次數(shù),因此類似于增加了模型的「深度」。調(diào)整誤差容忍度能允許我們在準(zhǔn)確度與計(jì)算成本之間做權(quán)衡,因此我們在訓(xùn)練時(shí)可以采用高準(zhǔn)確率而學(xué)習(xí)更好的神經(jīng)網(wǎng)絡(luò),在推斷時(shí)可以根據(jù)實(shí)際計(jì)算環(huán)境調(diào)整為較低的準(zhǔn)確度。
如原論文的上圖所示,a 圖表示模型能保證在誤差范圍為內(nèi),且隨著誤差降低,前向傳播的函數(shù)評(píng)估數(shù)增加。b 圖展示了評(píng)估數(shù)與相對(duì)計(jì)算時(shí)間的關(guān)系。d 圖展示了函數(shù)評(píng)估數(shù)會(huì)隨著訓(xùn)練的增加而自適應(yīng)地增加,這表明隨著訓(xùn)練的進(jìn)行,模型的復(fù)雜度會(huì)增加。
c 圖比較有意思,它表示前向傳播的函數(shù)評(píng)估數(shù)大致是反向傳播評(píng)估數(shù)的一倍,這恰好表示反向傳播中的 adjoint sensitivity 方法不僅內(nèi)存效率高,同時(shí)計(jì)算效率也比直接通過積分器的反向傳播高。這主要是因?yàn)?adjoint sensitivity 并不需要依次傳遞到前向傳播中的每一個(gè)函數(shù)評(píng)估,即梯度不通過模型的深度由后向前一層層傳。
反向傳播
師從同門的 Jesse Bettencourt 向機(jī)器之心介紹道,「天琦最擅長的就是耐心講解。」當(dāng)他遇到任何無論是代碼問題,理論問題還是數(shù)學(xué)問題,一旦是問了同桌的天琦,對(duì)方就一定會(huì)慢慢地花時(shí)間把問題講清楚、講透徹。而 ODEnet 的反向傳播,就是這樣一種需要耐心講解的問題。
ODEnet 的反向傳播與常見的反向傳播有一些不同,我們可能需要仔細(xì)查閱原論文與對(duì)應(yīng)的附錄證明才能有較深的理解。此外,作者給出了 ODEnet 的 PyTorch 實(shí)現(xiàn),我們也可以通過它了解實(shí)現(xiàn)細(xì)節(jié)。
正如作者而言,訓(xùn)練一個(gè)連續(xù)層級(jí)網(wǎng)絡(luò)的主要技術(shù)難點(diǎn)在于令梯度穿過 ODE Solver 的反向傳播。其實(shí)如果令梯度沿著前向傳播的計(jì)算路徑反傳回去是非常直觀的,但是內(nèi)存占用會(huì)比較大而且數(shù)值誤差也不能控制。作者的解決方案是將前向傳播的 ODE Solver 視為一個(gè)黑箱操作,梯度很難或根本不需要傳遞進(jìn)去,只需要「繞過」就行了。
作者采用了一種名為 adjoint method 的梯度計(jì)算方法來「繞過」前向傳播中的 ODE Solver,即模型在反傳中通過第二個(gè)增廣 ODE Solver 算出梯度,其可以逼近按計(jì)算路徑從 ODE Solver 傳遞回的梯度,因此可用于進(jìn)一步的參數(shù)更新。這種方法如上圖 c 所示不僅在計(jì)算和內(nèi)存非常有優(yōu)勢,同時(shí)還能精確地控制數(shù)值誤差。
具體而言,若我們的損失函數(shù)為 L(),且它的輸入為 ODE Solver 的輸出:
我們第一步需要求 L 對(duì) z(t) 的導(dǎo)數(shù),或者說模型損失的變化如何取決于隱藏狀態(tài) z(t) 的變化。其中損失函數(shù) L 對(duì) z(t_1) 的導(dǎo)數(shù)可以為整個(gè)模型的梯度計(jì)算提供入口。作者將這一個(gè)導(dǎo)數(shù)稱為 adjoint a(t) = -dL/z(t),它其實(shí)就相當(dāng)于隱藏層的梯度。
在基于鏈?zhǔn)椒▌t的傳統(tǒng)反向傳播中,我們需要從后一層對(duì)前一層求導(dǎo)以傳遞梯度。而在連續(xù)化的 ODEnet 中,我們需要將前面求出的 a(t) 對(duì)連續(xù)的 t 進(jìn)行求導(dǎo),由于 a(t) 是損失 L 對(duì)隱藏狀態(tài) z(t) 的導(dǎo)數(shù),這就和傳統(tǒng)鏈?zhǔn)椒▌t中的傳播概念基本一致。下式展示了 a(t) 的導(dǎo)數(shù),它能將梯度沿著連續(xù)的 t 向前傳,附錄 B.1 介紹了該式具體的推導(dǎo)過程。
在獲取每一個(gè)隱藏狀態(tài)的梯度后,我們可以再求它們對(duì)參數(shù)的導(dǎo)數(shù),并更新參數(shù)。同樣在 ODEnet 中,獲取隱藏狀態(tài)的梯度后,再對(duì)參數(shù)求導(dǎo)并積分后就能得到損失對(duì)參數(shù)的導(dǎo)數(shù),這里之所以需要求積分是因?yàn)椤笇蛹?jí)」t 是連續(xù)的。這一個(gè)方程式可以表示為:
綜上,我們對(duì) ODEnet 的反傳過程主要可以直觀理解為三步驟,即首先求出梯度入口伴隨 a(t_1),再求 a(t) 的變化率 da(t)/dt,這樣就能求出不同時(shí)刻的 a(t)。最后借助 a(t) 與 z(t),我們可以求出損失對(duì)參數(shù)的梯度,并更新參數(shù)。當(dāng)然這里只是簡要的直觀理解,更完整的反傳過程展示在原論文的算法 1。
反向傳播怎么做
在算法 1 中,陳天琦等研究者展示了如何借助另一個(gè) OED Solver 一次性求出反向傳播的各種梯度和更新量。要理解算法 1,首先我們要熟悉 ODESolver 的表達(dá)方式。例如在 ODEnet 的前向傳播中,求解過程可以表示為 ODEsolver(z(t_0), f, t_0, t_1, θ),我們可以理解為從 t_0 時(shí)刻開始令 z(t_0) 以變化率 f 進(jìn)行演化,這種演化即 f 在 t 上的積分,ODESolver 的目標(biāo)是通過積分求得 z(t_1)。
同樣我們能以這種方式理解算法 1,我們的目的是利用 ODESolver 從 z(t_1) 求出 z(t_0)、從 a(t_1) 按照方程 4 積出 a(t_0)、從 0 按照方程 5 積出 dL/dθ。最后我們只需要使用 dL/dθ 更新神經(jīng)網(wǎng)絡(luò) f(z(t), t, θ) 就完成了整個(gè)反向傳播過程。
如上所示,若初始給定參數(shù)θ、前向初始時(shí)刻 t_0 和終止時(shí)刻 t_1、終止?fàn)顟B(tài) z(t_1) 和梯度入口 ?L/?z(t_1)。接下來我們可以將三個(gè)積分都并在一起以一次性解出所有量,因此我們可以定義初始狀態(tài) s_0,它們是解常微分方程的初值。
注意第一個(gè)初值 z(t_1),其實(shí)在前向傳播中,從 z(t_0) 到 z(t_1) 都已經(jīng)算過一遍了,但是模型并不會(huì)保留計(jì)算結(jié)果,因此也就只有常數(shù)級(jí)的內(nèi)存成本。此外,在算 a(t) 時(shí)需要知道對(duì)應(yīng)的 z(t),例如 ?L/?z(t_0) 就要求知道 z(t_0) 的值。如果我們不能保存中間狀態(tài)的話,那么也可以從 z(t_1) 到 z(t_0) 反向再算一遍中間狀態(tài)。這個(gè)計(jì)算過程和前向過程基本一致,即從 z(t_1) 開始以變化率 f 進(jìn)行演化而推出 z(t_0)。
定義 s_0 后,我們需要確定初始狀態(tài)都是怎樣「演化」到終止?fàn)顟B(tài)的,定義這些演化的即前面方程 (3)、(4) 和 (5) 的被積函數(shù),也就是算法 1 中 aug_dynamics() 函數(shù)所定義的。
其中 f(z(t), t, θ) 從 t_1 到 t_0 積出來為 z(t_0),這第一個(gè)常微分方程是為了給第二個(gè)提供條件。而-a(t)*?L/?z(t) 從 t_1 到 t_0 積出來為 a(t_0),它類似于傳統(tǒng)神經(jīng)網(wǎng)絡(luò)中損失函數(shù)對(duì)第一個(gè)隱藏層的導(dǎo)數(shù),整個(gè) a(t) 就相當(dāng)于隱藏層的梯度。只有獲取積分路徑中所有隱藏層的梯度,我們才有可能進(jìn)一步解出損失函數(shù)對(duì)參數(shù)的梯度。
因此反向傳播中的第一個(gè)和第二個(gè)常微分方程 都是為第三個(gè)微分方程提供條件,即 a(t) 和 z(t)。最后,從 t_1 到 t_0 積分 -a(t)*?f(z(t), t, θ)/?θ 就能求出 dL/dθ。只需要一個(gè)積分,我們不再一層層傳遞梯度并更新該層特定的參數(shù)。
如下偽代碼所示,完成反向傳播的步驟很簡單。先定義各變量演化的方法,再結(jié)合將其結(jié)合初始化狀態(tài)一同傳入 ODESolver 就行了。
????return[f,?-a*df/da,?-a*df/dθ]
[z0,?dL/dx,?dL/dθ]?=?
????????ODESolver([z(t1),?dL/dz(t),?0],?f_and_a,?t1,?t0)
連續(xù)型的歸一化流
這種連續(xù)型轉(zhuǎn)換有一個(gè)非常重要的屬性,即流模型中最基礎(chǔ)的變量代換定理可以便捷快速地計(jì)算得出。在論文的第四節(jié)中,作者根據(jù)這樣的推導(dǎo)結(jié)果構(gòu)建了一個(gè)新型可逆密度模型,它能克服 Glow 等歸一化流模型的缺點(diǎn),并直接通過最大似然估計(jì)訓(xùn)練。
變量代換定理
對(duì)于概率密度估計(jì)中的變量代換定理,我們可以從單變量的情況開始。若給定一個(gè)隨機(jī)變量 z 和它的概率密度函數(shù) z~π(z),我們希望使用映射函數(shù) x=f(z) 構(gòu)建一個(gè)新的隨機(jī)變量。函數(shù) f 是可逆的,即 z=g(x),其中 f 和 g 互為逆函數(shù)。現(xiàn)在問題是如何推斷新變量的未知概率密度函數(shù) p(x)?
通過定義,積分項(xiàng) ∫π(z)dz 表示無限個(gè)無窮小的矩形面積之和,其中積分元Δz 為積分小矩形的寬,小矩形在位置 z 的高為概率密度函數(shù) π(z) 定義的值。若使用 f^?1(x) 表示 f(x) 的逆函數(shù),當(dāng)我們替換變量的時(shí)候,z=f^?1(x) 需要服從 Δz/Δx=(f^?1(x))′。多變量的變量代換定理可以從單變量推廣而出,其中 det ?f/?z 為函數(shù) f 的雅可比行列式:
一般使用變量代換定理需要計(jì)算雅可比矩陣?f/?z 的行列式,這是主要的限制,最近的研究工作都在權(quán)衡歸一化流模型隱藏層的表達(dá)能力與計(jì)算成本。但是研究者發(fā)現(xiàn),將離散的層級(jí)替換為連續(xù)的轉(zhuǎn)換,可以簡化計(jì)算,我們只需要算雅可比矩陣的跡就行了。核心的定理 1 如下所示:
在普通的變量代換定理中,分布的變換函數(shù) f(或神經(jīng)網(wǎng)絡(luò))必須是可逆的,而且要制作可逆的神經(jīng)網(wǎng)絡(luò)也很復(fù)雜。在陳天琦等研究者定理里,不論 f 是什么樣的神經(jīng)網(wǎng)絡(luò)都沒問題,它天然可逆,所以這種連續(xù)化的模型對(duì)流模型的應(yīng)用應(yīng)該非常方便。
如下所示,隨機(jī)變量 z(t_0) 及其分布可以通過一個(gè)連續(xù)的轉(zhuǎn)換演化到 z(t_1) 及其分布:
此外,連續(xù)型流模型還有很多性質(zhì)與優(yōu)勢,但這里并不展開。變量代換定理 1 在附錄 A 中有完整的證明,感興趣的讀者可查閱原論文了解細(xì)節(jié)。
最后,神經(jīng)常微分方程是一種全新的框架,除了流模型外,很多方法在連續(xù)變換的改變下都有新屬性,這些屬性可能在離散激活的情況下很難獲得。也許未來會(huì)有很多的研究關(guān)注這一新模型,連續(xù)化的神經(jīng)網(wǎng)絡(luò)也會(huì)變得多種多樣。
本文為機(jī)器之心原創(chuàng),轉(zhuǎn)載請(qǐng)聯(lián)系本公眾號(hào)獲得授權(quán)。
?------------------------------------------------
加入機(jī)器之心(全職記者 / 實(shí)習(xí)生):hr@jiqizhixin.com
投稿或?qū)で髨?bào)道:content@jiqizhixin.com
廣告 & 商務(wù)合作:bd@jiqizhixin.com
總結(jié)
以上是生活随笔為你收集整理的硬核NeruIPS 2018最佳论文,一个神经了的常微分方程的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 继BERT之后,这个新模型再一次在11项
- 下一篇: 深度学习时代的图模型,清华发文综述图网络