(pytorch-深度学习)通过时间反向传播
通過時間反向傳播
介紹循環神經網絡中梯度的計算和存儲方法,即通過時間反向傳播(back-propagation through time)。
- 正向傳播和反向傳播相互依賴。
- 正向傳播在循環神經網絡中比較直觀,而通過時間反向傳播其實是反向傳播在循環神經網絡中的具體應用。
- 我們需要將循環神經網絡按時間步展開,從而得到模型變量和參數之間的依賴關系,并依據鏈式法則應用反向傳播計算并存儲梯度。
定義模型
考慮一個簡單的無偏差項的循環神經網絡,且激活函數為恒等映射(?(x)=x\phi(x)=x?(x)=x)。設時間步 ttt 的輸入為單樣本 xt∈Rd\boldsymbol{x}_t \in \mathbb{R}^dxt?∈Rd,標簽為 yty_tyt?,那么隱藏狀態 ht∈Rh\boldsymbol{h}_t \in \mathbb{R}^hht?∈Rh的計算表達式為
ht=Whxxt+Whhht?1,\boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1}, ht?=Whx?xt?+Whh?ht?1?,
其中Whx∈Rh×d\boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}Whx?∈Rh×d和Whh∈Rh×h\boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}Whh?∈Rh×h是隱藏層權重參數。設輸出層權重參數Wqh∈Rq×h\boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}Wqh?∈Rq×h,時間步ttt的輸出層變量ot∈Rq\boldsymbol{o}_t \in \mathbb{R}^qot?∈Rq計算為
ot=Wqhht.\boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}. ot?=Wqh?ht?.
設時間步ttt的損失為?(ot,yt)\ell(\boldsymbol{o}_t, y_t)?(ot?,yt?)。時間步數為TTT的損失函數LLL定義為
L=1T∑t=1T?(ot,yt).L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t). L=T1?t=1∑T??(ot?,yt?).
將LLL稱為有關給定時間步的數據樣本的目標函數。
模型計算圖
為了可視化循環神經網絡中模型變量和參數在計算中的依賴關系,我們可以繪制模型計算圖,像下圖。例如,時間步3的隱藏狀態h3\boldsymbol{h}_3h3?的計算依賴模型參數Whx\boldsymbol{W}_{hx}Whx?、Whh\boldsymbol{W}_{hh}Whh?、上一時間步隱藏狀態h2\boldsymbol{h}_2h2?以及當前時間步輸入x3\boldsymbol{x}_3x3?。
表示了時間步數為3的循環神經網絡模型計算中的依賴關系。
- 方框代表變量(無陰影)或參數(有陰影),圓圈代表運算符
方法
圖中的模型的參數是 Whx\boldsymbol{W}_{hx}Whx?, Whh\boldsymbol{W}_{hh}Whh? 和 Wqh\boldsymbol{W}_{qh}Wqh?。訓練模型通常需要模型參數的梯度?L/?Whx\partial L/\partial \boldsymbol{W}_{hx}?L/?Whx?、?L/?Whh\partial L/\partial \boldsymbol{W}_{hh}?L/?Whh?和?L/?Wqh\partial L/\partial \boldsymbol{W}_{qh}?L/?Wqh?。 圖中的依賴關系,我們可以按照其中箭頭所指的反方向依次計算并存儲梯度。
- 首先,目標函數有關各時間步輸出層變量的梯度?L/?ot∈Rq\partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q?L/?ot?∈Rq很容易計算:
?L?ot=??(ot,yt)T??ot.\frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}.?ot??L?=T??ot???(ot?,yt?)?.
- 之后,可以計算目標函數有關模型參數Wqh\boldsymbol{W}_{qh}Wqh?的梯度?L/?Wqh∈Rq×h\partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}?L/?Wqh?∈Rq×h。根據計算圖,LLL通過o1,…,oT\boldsymbol{o}_1, \ldots, \boldsymbol{o}_To1?,…,oT?依賴Wqh\boldsymbol{W}_{qh}Wqh?。依據鏈式法則,
?L?Wqh=∑t=1Tprod(?L?ot,?ot?Wqh)=∑t=1T?L?otht?.\frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top. ?Wqh??L?=t=1∑T?prod(?ot??L?,?Wqh??ot??)=t=1∑T??ot??L?ht??.
- 其次,隱藏狀態之間也存在依賴關系。 在計算圖中,LLL只通過oT\boldsymbol{o}_ToT?依賴最終時間步TTT的隱藏狀態hT\boldsymbol{h}_ThT?。因此,我們先計算目標函數有關最終時間步隱藏狀態的梯度?L/?hT∈Rh\partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h?L/?hT?∈Rh。依據鏈式法則,我們得到
?L?hT=prod(?L?oT,?oT?hT)=Wqh??L?oT.\frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}. ?hT??L?=prod(?oT??L?,?hT??oT??)=Wqh???oT??L?.
-
接下來對于時間步t<Tt < Tt<T, 在計算圖中,LLL通過ht+1\boldsymbol{h}_{t+1}ht+1?和ot\boldsymbol{o}_tot?依賴ht\boldsymbol{h}_tht?。依據鏈式法則, 目標函數有關時間步t<Tt < Tt<T的隱藏狀態的梯度?L/?ht∈Rh\partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h?L/?ht?∈Rh需要按照時間步從大到小依次計算:
?L?ht=prod(?L?ht+1,?ht+1?ht)+prod(?L?ot,?ot?ht)=Whh??L?ht+1+Wqh??L?ot\frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}{t+1}}, \frac{\partial \boldsymbol{h}{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t} ?ht??L?=prod(?ht+1?L?,?ht??ht+1?)+prod(?ot??L?,?ht??ot??)=Whh???ht+1??L?+Wqh???ot??L? -
將上面的遞歸公式展開,對任意時間步1≤t≤T1 \leq t \leq T1≤t≤T,我們可以得到目標函數有關隱藏狀態梯度的通項公式
?L?ht=∑i=tT(Whh?)T?iWqh??L?oT+t?i.\frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}. ?ht??L?=i=t∑T?(Whh??)T?iWqh???oT+t?i??L?.
由上式中的指數項可見,當時間步數 TTT 較大或者時間步 ttt 較小時,目標函數有關隱藏狀態的梯度較容易出現衰減和爆炸。這也會影響其他包含?L/?ht\partial L / \partial \boldsymbol{h}_t?L/?ht?項的梯度,例如隱藏層中模型參數的梯度?L/?Whx∈Rh×d\partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}?L/?Whx?∈Rh×d和?L/?Whh∈Rh×h\partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}?L/?Whh?∈Rh×h。 在圖中,LLL通過h1,…,hT\boldsymbol{h}_1, \ldots, \boldsymbol{h}_Th1?,…,hT?依賴這些模型參數。 依據鏈式法則,有
?L?Whx=∑t=1Tprod(?L?ht,?ht?Whx)=∑t=1T?L?htxt?,?L?Whh=∑t=1Tprod(?L?ht,?ht?Whh)=∑t=1T?L?htht?1?.\begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned} ?Whx??L???Whh??L??=t=1∑T?prod(?ht??L?,?Whx??ht??)=t=1∑T??ht??L?xt??,=t=1∑T?prod(?ht??L?,?Whh??ht??)=t=1∑T??ht??L?ht?1??.?
每次迭代中,我們在依次計算完以上各個梯度后,會將它們存儲起來,從而避免重復計算。
- 例如,由于隱藏狀態梯度?L/?ht\partial L/\partial \boldsymbol{h}_t?L/?ht?被計算和存儲,之后的模型參數梯度?L/?Whx\partial L/\partial \boldsymbol{W}_{hx}?L/?Whx?和?L/?Whh\partial L/\partial \boldsymbol{W}_{hh}?L/?Whh?的計算可以直接讀取?L/?ht\partial L/\partial \boldsymbol{h}_t?L/?ht?的值,而無須重復計算它們。
- 此外,反向傳播中的梯度計算可能會依賴變量的當前值。它們正是通過正向傳播計算出來的。 舉例來說,參數梯度?L/?Whh\partial L/\partial \boldsymbol{W}_{hh}?L/?Whh?的計算需要依賴隱藏狀態在時間步t=0,…,T?1t = 0, \ldots, T-1t=0,…,T?1的當前值ht\boldsymbol{h}_tht?(h0\boldsymbol{h}_0h0?是初始化得到的)。這些值是通過從輸入層到輸出層的正向傳播計算并存儲得到的。
總結
以上是生活随笔為你收集整理的(pytorch-深度学习)通过时间反向传播的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 刚一下雪,中国就美哭了全世界!
- 下一篇: 毕业大论文到底怎么写?