Neural Ordinary Differential Equations
神經常微分方程(2018)
- Abstract
- 1 Introduction
- 2 Reverse-mode automatic differentiation of ODE solutions(反向模式的自動微分ODE的解決方案)
- 3 Replacing residual networks with ODEs for supervised learning
- 4 Continuous Normalizing Flows
- 4.1 CNF試驗
Abstract
我們引入了一種新的深度神經網絡模型家族。我們沒有指定一個離散的隱藏層序列,而是使用神經網絡參數化隱藏狀態的導數。該網絡的輸出是用一個黑盒微分方程求解器來計算的。這些連續深度模型具有恒定的內存成本,可以根據每個輸入調整其評估策略,并可以明確地用數值精度換取速度。我們在連續深度剩余網絡和連續時間潛變量模型中證明了這些特性。我們還構造了連續的歸一化流,一個生成模型,可以通過最大似然進行訓練,而不需要對數據維進行劃分或排序。對于訓練,我們展示了如何通過任何ODE求解器可伸縮地反向傳播,而不訪問其內部操作。這允許在更大的模型中對ode進行端到端訓練。
1 Introduction
殘差網絡、循環神經網絡解碼器和正則化流等模型通過組合一系列到隱藏狀態的轉換來構建復雜的轉換:
ht+1=ht+f(θt,ht)(1)h_{t+1}=h_t+f(\theta_t,h_t)\tag{1}ht+1?=ht?+f(θt?,ht?)(1)其中t∈{1,...,T},ht∈Rdt\in\{1,...,T\},h_t\in\mathbb R^dt∈{1,...,T},ht?∈Rd,這些迭代更新可以看作是一個連續變換的歐拉離散化。
當我們添加更多的圖層和采取更小的步驟時,會發生什么?在極限情況下,我們使用一個由神經網絡指定的常微分方程(ODE)來參數化隱藏單元的連續動態:
dh(t)dt=f(h(t),t,θ)(2)\frac{dh(t)}{dt}=f(h(t),t,\theta)\tag{2}dtdh(t)?=f(h(t),t,θ)(2)從輸入層h(0)h(0)h(0),我們可以定義輸出層h(T)h(T)h(T)為在T時刻上ODE初值問題的解,這個值可以通過黑盒微分方程求解器計算,它評估隱藏神經元動力學f 在任何需要的地方求解符合精度要求的解。圖1對比了這兩種方法。
使用ODE求解器來定義和評估模型有幾個好處:
Memory efficiency: 在第2節中,我們展示了如何計算關于任何ODE求解器的所有輸入的標量值損失的梯度,而不通過求解器的操作反向傳播。如果不存儲任何中間數量的正向傳遞,我們就可以用恒定的內存成本作為深度的函數來訓練我們的模型,這是訓練深度模型的一個主要瓶頸。
Adaptive computation: 歐拉的方法可能是求解ode的最簡單的方法。從那以后,高效和精確的ODE求解器已經發展了120多年?,F代ODE求解器為近似誤差的增長提供了保證,監測誤差水平,并動態調整其評估策略,以達到所要求的精度水平。這使得評估模型的成本與問題的復雜性而變化。經過訓練后,實時或低功耗應用程序的精度可以降低。
Scalable and invertible normalizing flows: 連續變換的一個意想不到的副作用是,變量公式的變化變得更容易計算。在第4節中,我們推導了這個結果,并利用它來構造一類新的可逆密度模型,它避免了規范化流的單個單元瓶頸,并且可以直接通過最大似然進行訓練。
Continuous time-series models: 與需要離散觀測和發射間隔的循環神經網絡不同,連續定義的動態可以自然地合并在任意時間到達的數據。在第5節中,我們構建并演示了這樣一個模型。
2 Reverse-mode automatic differentiation of ODE solutions(反向模式的自動微分ODE的解決方案)
訓練連續深度網絡的主要技術困難是通過ODE求解器執行反向模式微分(也稱為反向傳播)。通過前向傳遞的操作進行區分是很簡單的,但會導致很高的內存成本,并引入額外的數值誤差。
我們將ODE求解器視為一個黑盒子,并使用伴隨靈敏度方法計算梯度(Pontryaginetal.,1962)。該方法通過求解第二個、增強了的時間向后(時間軸反向)的 ODE 來計算梯度,適用于所有的ODE求解器。這種方法與問題的大小成線性關系,內存成本低,并顯式地控制數值誤差。
考慮優化一個標量值損失函數L()L()L(),它的前向傳播過程可以如下表示:
z(t1)z(t_1)z(t1?)代表t1t_1t1?時刻的隱藏狀態,而當隱藏狀態被連續化后,t0t_0t0?到t1t_1t1?時刻的中間隱藏狀態的和就是等式中間部分的積分項。而整個前向過程可以用 ODE 求解器進行求解。
為了優化L,我們需要對θ求梯度。第一步就是要求L在每一個時刻對隱狀態z(t)的梯度,這個量被稱為伴隨矩陣a(t)=?L/?z(t)a(t)=?L/?z(t)a(t)=?L/?z(t)。它的動態過程被另一個 ODE 來求解,可以把這種瞬時性被看作鏈式法則:
da(t)dt=?a(t)T?f(z(t),θ,t)?z(4)\frac{da(t)}{dt}=-a(t)^T\frac{\partial f(z(t),\theta,t)}{\partial z}\tag{4}dtda(t)?=?a(t)T?z?f(z(t),θ,t)?(4)這樣, 再調一次求解器就可以解出?L/?z(t0)?L/?z(t_0)?L/?z(t0?)。
這個求解器從初始值?L/?z(t1)?L/?z(t_1)?L/?z(t1?)開始反向運行。一個復雜的問題是,解決這個ODE需要知道z(t)z(t)z(t)沿其整個軌跡的值。然而,我們可以簡單地從最終值z(t1)z(t_1)z(t1?)開始,將它的伴隨z(t)z(t)z(t)一起反向重新計算。
計算關于參數θ的梯度需要計算第三個積分,它同時取決于z(t)和a(t):
(4)和(5)中的a(t)T?f?za(t)^T\frac{\partial f}{\partial z}a(t)T?z?f?和a(t)T?f?θa(t)^T\frac{\partial f}{\partial \theta}a(t)T?θ?f?的vector-Jacobian products 都可以通過 ODE solver 快速求解, 所有的積分解z,a和?L/?θ?L/?θ?L/?θ都可以通過一個 ODE solver 來求解,可以將它們組合成一個向量解 (增強的狀態,augmented state)。算法1展示了如何構造必要的動態,并調用一個ODE求解器來一次計算所有的梯度。
大多數ODE求解器都可以選擇多次輸出狀態z(t)。當損失依賴于這些中間狀態時,反向偏導數必須被分解成一個單獨的解序列,在每對連續的輸出時間之間有一個解(圖2)。在每次觀測時,按相應的偏導數?L/?z(ti)?L/?z(t_i)?L/?z(ti?)方向調整。
由損失敏感度?L?z(tN)\frac{\partial L}{\partial {z(t_N)}}?z(tN?)?L?調節伴隨狀態a(t), 然后再有伴隨狀態 a(t) 得到損失敏感度?L?z(tN)\frac{\partial L}{\partial {z(t_N)}}?z(tN?)?L?。這是 ODE 反向的鏈式過程。至此,模擬了整個反向傳播的過程
3 Replacing residual networks with ODEs for supervised learning
在本節中,我們將實驗研究監督學習的神經ode的訓練。
Software 為了從數值上解決ODE初值問題,我們使用了在LSODE和VODE中實現的隱式Adams方法,并通過scipy。集成包進行接口。作為一種隱式方法,它比龍格-庫塔等顯式方法有更好的保證,但需要在每一步都要求解一個非線性優化問題。這種設置使得通過集成器的直接反向傳播變得困難。我們在Python的自動網格框架中實現了伴隨靈敏度方法(Maclaurinetal.,2015)。在本節的實驗中,我們使用張量流評估了GPU上的隱藏狀態動力學及其導數,然后從FortranODE求解器中調用,從Python自動grad代碼中調用。
Model Architectures 我們實驗了一個小殘差網絡,該網絡對輸入進行兩次降采樣,然后應用6個標準殘差塊He等人(2016b),它們被ODE-Net變體中的ODESolve模塊所取代。我們還測試了一個具有相同架構的網絡,但梯度直接通過龍格-庫塔積分器反向傳播,稱為RK-Net。表1顯示了測試誤差、參數數量和內存成本。L表示ResNet中的層數,L~\widetilde LL是ODE求解器在單個向前傳遞中請求的函數計算數,可以解釋為隱式的層數。我們發現ODE-Nets和RK-nets可以實現與ResNet幾乎相同的性能。
Error Control in ODE-Nets ODE求解器可以近似地確保輸出在真實解的給定容忍度范圍內。更改此公差會改變網絡的行為。我們首先驗證了在圖3a中確實可以控制錯誤。前向調用所花費的時間與函數評估的數量成正比(圖3b),因此調整公差給了我們一個在精度和計算成本之間的權衡。人們可以進行高精度的訓練,但在測試時會切換到較低的精度。
Network Depth目前尚不清楚如何定義ODE解決方案的“深度”。一個相關的數量是所需的隱藏狀態動態計算的數量,這個細節委托給ODE求解器,并依賴于初始狀態或輸入。圖3d顯示,在整個訓練過程中,功能評估的數量在訓練過程中不斷增加,這可能是為了適應模型不斷增加的復雜性。
4 Continuous Normalizing Flows
離散化的方程(1)也出現在規范化流(Rezende和Mohamed,2015)和NICE框架(Dinh等人,2014)中。這些方法利用變量變化定理來計算樣本通過雙射函數fff 進行變換時概率的精確變化:z1=f(z0)?log?p(z1)=log?p(z0)?log?∣det??f?z0∣\mathbf{z}_{1}=f\left(\mathbf{z}_{0}\right) \Longrightarrow \log p\left(\mathbf{z}_{1}\right)=\log p\left(\mathbf{z}_{0}\right)-\log \left|\operatorname{det} \frac{\partial f}{\partial \mathbf{z}_{0}}\right|z1?=f(z0?)?logp(z1?)=logp(z0?)?log∣∣?det?z0??f?∣∣?
經典的正則化流模型, planar normalization flows的公式如下:
z(t+1)=z(t)+uh(w?z(t)+b),log?p(z(t+1))=log?p(z(t))?log?∣1+u??h?z∣\mathbf{z}(t+1)=\mathbf{z}(t)+u h\left(w^{\top} \mathbf{z}(t)+b\right), \quad \log p(\mathbf{z}(t+1))=\log p(\mathbf{z}(t))-\log \left|1+u^{\top} \frac{\partial h}{\partial \mathbf{z}}\right|z(t+1)=z(t)+uh(w?z(t)+b),logp(z(t+1))=logp(z(t))?log∣∣?1+u??z?h?∣∣?
使用變量代換公式的瓶頸是計算雅克比矩陣。它的計算復雜度要么是z維度的立方, 要么是隱藏單元數量的立方,最近的研究都是在NF模型的表達能力和計算復雜度做取舍。
令人驚訝的是,從一組離散的層移動到一個連續的變換,簡化了規范化常數變化的計算。
定理1 變量瞬時變化定理
設z(t)z(t)z(t) 是一個有限連續隨機變量,概率p(z(t))p(z(t))p(z(t)) 依賴于時間. 則dzdt=f(z(t),t)\frac{dz}{dt}=f(z(t),t)dtdz?=f(z(t),t)是z(t)z(t)z(t) 隨時間連續變化的微分方程,假設fff 關于z一致LipschitzLipschitzLipschitz 連續,關于ttt 連續,那么對數概率密度的變化也遵循微分方程?log?p(z(t))?t=?tr?(dfdz(t))(8)\frac{\partial \log p(\mathbf{z}(t))}{\partial t}=-\operatorname{tr}\left(\frac{d f}{d \mathbf{z}(t)}\right)\tag{8}?t?logp(z(t))?=?tr(dz(t)df?)(8)
proofproofproof
為了證明這個定理,我們取了logp(z(t))logp(z(t))logp(z(t)) 隨時間的有限變化的無窮小極限。首先,我們表示zzz 對εεε 的時間變化的變換為z(t+?)=T?(z(t))(14)\mathbf z(t+\epsilon)=T_\epsilon(\mathbf z(t))\tag{14}z(t+?)=T??(z(t))(14)
我們假設fff 在z(t)z(t)z(t) 上是Lipschitz連續的,在t上是連續的,因此每個初值問題通過皮卡德存在性定理都有一個唯一解。我們還假設z(t)z(t)z(t) 是有界的。這些條件表明f,Tεf,T_εf,Tε? 和??zTε\frac{?}{?z}T_ε?z??Tε? 都是有界的。在下面,我們使用這些條件來交換極限和乘積。
我們利用用變量的離散變化公式表示微分方程?logp(z(t))?t\frac{?logp(z(t))}{?t}?t?logp(z(t))?,以及導數的定義:
行列式的導數可以用雅可比公式表示,則有
行列式求導公式d∣A∣dt=tr(A?dAdt)\frac{d|A|}{dt}=tr(A^*\frac{dA}{dt})dtd∣A∣?=tr(A?dtdA?)
用TεT_εTε?的泰勒級數展開式代替TεT_εTε?并取極限,完成了證明。
與(6)的logloglog 計算不同, 本式只需要計算跡(trace)的操作。另外, 不像標準的NF模型, 本式不要求f是可逆的, 因為如果滿足唯一性,那么整個轉換自然就是可逆的。
應用變量瞬時變化定理,我們可以看一下planar normalization flows的連續模擬版本:
dz(t)dt=uh(w?z(t)+b),?log?p(z(t))?t=?u??h?z(t)(9)\frac{d \mathbf{z}(t)}{d t}=u h\left(w^{\top} \mathbf{z}(t)+b\right), \quad \frac{\partial \log p(\mathbf{z}(t))}{\partial t}=-u^{\top} \frac{\partial h}{\partial \mathbf{z}(t)}\tag{9}dtdz(t)?=uh(w?z(t)+b),?t?logp(z(t))?=?u??z(t)?h?(9)
給定一個初始分布p(z(0)),我們可以從p(z(T))中采樣,并通過求解這組ODE來評估其概率密度。
使用多個線性成本的隱藏單元
當det不是線性方程時, 跡的方程還是線性的, 并且滿足tr(∑nJn=∑ntr(Jn))tr(\sum_{n}J_n=\sum_ntr(J_n))tr(∑n?Jn?=∑n?tr(Jn?)) ,這樣我們的方程就可以由一系列的求和得到, 概率密度的微分方程也是一個求和:
dz(t)dt=∑n=1Mfn(z(t)),dlog?p(z(t))dt=∑n=1Mtr?(?fn?z)(10)\frac{d \mathbf{z}(t)}{d t}=\sum_{n=1}^{M} f_{n}(\mathbf{z}(t)), \quad \frac{d \log p(\mathbf{z}(t))}{d t}=\sum_{n=1}^{M} \operatorname{tr}\left(\frac{\partial f_{n}}{\partial \mathbf{z}}\right)\tag{10}dtdz(t)?=n=1∑M?fn?(z(t)),dtdlogp(z(t))?=n=1∑M?tr(?z?fn??)(10)這意味著我們可以很簡便的評估多隱藏單元的流模型,其成本僅與隱藏單元M的數量呈線性關系。使用標準的NF模型評估這種“寬”層的成本是O(M3)O(M^3)O(M3),這意味著標準NF體系結構的多個層只使用單個隱藏單元.
依賴于時間的動態方程
我們可以將流的參數指定為t的函數,使微分方程f(z(t),t)f(z(t),t)f(z(t),t)隨ttt 而變化。這種參數化的方法是一種超網絡. 我們還為每個隱藏層引入了門機制,dzdt=∑nσn(t)fn(Z)\frac{d \mathbf{z}}{d t}=\sum_{n} \sigma_{n}(t) f_{n}(\mathbf{Z})dtdz?=∑n?σn?(t)fn?(Z) ,σn(t)∈(0,1)\sigma_n(t)\in (0,1)σn?(t)∈(0,1)是一個神經網絡, 可以學習到何時使用fn. 我們把該模型稱之為連續正則化流(CNF, continuous normalizing flows)
4.1 CNF試驗
我們首先比較連續的和離散的planar規范化流在學習一個已知的分布樣本。我們證明了一個具有M個隱藏單元的連續 planar CNF至少可以與一個具有K層(M = K)的離散 planar NF具有同樣的擬合能力,某些情況下CNF的擬合能力甚至更強。
擬合概率密度
設置一個前述的CNF, 用adam優化器訓練10000個step. 對應的NF使用RMSprop訓練500000個step. 此任務中損失函數為KL(q(x)∣∣p(x))KL (q(x)||p(x))KL(q(x)∣∣p(x)) , 最小化這個損失函數, 來用q(x)q(x)q(x) 擬合目標概率分布p(x)p(x)p(x) . 圖4表明, CNF可以得到更低的損失.
極大似然訓練
CNF一個有用的特性是: 計算反向轉換和正向的成本差不多, 這一點是NF模型做不到的. 這樣在用CNF模型做概率密度估計任務時, 我們可以通過極大似然估計來進行訓練 也就是最大化Ep(x)[log(q(x))]\mathbb E_{p(x)}[log(q(x))]Ep(x)?[log(q(x))] ,其中qqq 是變量代換之后的函數,然后反向轉換CNF來從q(x)q(x)q(x) 中進行采樣
對于這個實驗, 我們使用64個隱藏單元的CNF和64層的NF來進行對比,圖5展示了最終的訓練結果. 從最初的高斯分布, 到最終學到的分布, 每一個圖代表時間t的某一步. 有趣的是: 為了擬合兩個圓圈, CNF把planar 流 進行了旋轉, 這樣粒子會均分到兩個圓中. 跟 CNF的平滑可解釋相對的是, NF模型比較反直覺, 并且很難擬合雙月牙的概率分布(見圖5.b)
總結
以上是生活随笔為你收集整理的Neural Ordinary Differential Equations的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 苹果手机(ipone)点击元素,事件不执
- 下一篇: h5页面苹果手机不兼容普通点击事件