深度学习之循环神经网络(10)GRU简介
深度學習之循環神經網絡(10)GRU簡介
- 1. 復位門
- 2. 更新門
- 3. GRU使用方法
?LSTM具有更長的記憶能力,在大部分序列任務上面都取得了比基礎RNN模型更好的性能表現,更重要的是,LSTM不容易出現梯度彌散現象。但是LSTM結構相對較復雜,計算代價較高,模型參數量較大。因此科學家們嘗試簡化LSTM內部的計算流程,特別是減少門控數量。研究發現,遺忘門是LSTM中最重要的門控 [1],甚至發現只有遺忘門的簡化版網絡在多個基準數據集上面優于標準LSTM網絡。在眾多的簡化版LSTM中, 門控循環網絡(Gated Recurrent Unit,簡稱GRU)是應用最廣泛的RNN變種之一。GRU把內部狀態向量和輸出向量合并,統一為狀態向量 h\boldsymbol hh,門控數量也較少到2個: 復位門(Reset Gate)和 更新門(Update Gate),如下圖所示:
GRU網絡結構
?下面我們來分別介紹復位門和更新門的原理與功能。
[1] J. Westhuizen 和 J. Lasenby, “The unreasonable effectiveness of the forget gate,” CoRR, 卷 abs/1804.04849, 2018.
1. 復位門
?復位門用于控制上一個時間戳的狀態ht?1\boldsymbol h_{t-1}ht?1?進入GRU的量。門控向量gr\boldsymbol g_rgr?由當前時間戳輸入xt\boldsymbol x_txt?和上一時間戳狀態ht?1\boldsymbol h_{t-1}ht?1?變換得到,關系如下:
gr=σ(Wr[ht?1,xt]+br)\boldsymbol g_r=σ(\boldsymbol W_r [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_r)gr?=σ(Wr?[ht?1?,xt?]+br?)
其中Wr\boldsymbol W_rWr?和br\boldsymbol b_rbr?為復位門的參數,由反向傳播算法自動優化,σσσ為激活函數,一般使用Sigmoid函數。門控向量gr=0\boldsymbol g_r=0gr?=0時,新輸入h~t\tilde \boldsymbol h_th~t?全部來自于輸入xt\boldsymbol x_txt?,不接受ht?1\boldsymbol h_{t-1}ht?1?,此時相當于復位ht?1\boldsymbol h_{t-1}ht?1?。當gr=1\boldsymbol g_r=1gr?=1時,ht?1h_{t-1}ht?1?和輸入xt\boldsymbol x_txt?共同產生新輸入h~t\tilde\boldsymbol h_th~t?,如下圖所示:
2. 更新門
?更新門用控制上一時間戳狀態ht?1\boldsymbol h_{t-1}ht?1?和新輸入h~t\tilde\boldsymbol h_th~t?對新狀態向量ht\boldsymbol h_tht?的影響程度。更新門控向量gz\boldsymbol g_zgz?由
gz=σ(Wz[ht?1,xt]+bz)\boldsymbol g_z=σ(\boldsymbol W_z [\boldsymbol h_{t-1},\boldsymbol x_t ]+\boldsymbol b_z)gz?=σ(Wz?[ht?1?,xt?]+bz?)
得到,其中Wz\boldsymbol W_zWz?和bz\boldsymbol b_zbz?為更新門的參數,由反向傳播算法自動優化,σσσ為激活函數,一般使用Sigmoid函數。gz\boldsymbol g_zgz?用于控制新輸入h~t\tilde\boldsymbol h_th~t?信號,1?gz1-\boldsymbol g_z1?gz?用于控制狀態ht?1\boldsymbol h_{t-1}ht?1?信號:
ht=(1?gz)ht?1+gzh~t\boldsymbol h_t=(1-\boldsymbol g_z ) \boldsymbol h_{t-1}+\boldsymbol g_z \tilde\boldsymbol h_tht?=(1?gz?)ht?1?+gz?h~t?
可以看到,h~t\tilde\boldsymbol h_th~t?和ht?1\boldsymbol h_{t-1}ht?1?的更新量處于相互競爭、此消彼長的狀態。當更新門gz=0\boldsymbol g_z=0gz?=0時,ht\boldsymbol h_tht?全部來自上一時間戳狀態ht?1\boldsymbol h_{t-1}ht?1?;當更新門gz=1\boldsymbol g_z=1gz?=1時,ht\boldsymbol h_tht?全部來自新輸入h~t\tilde\boldsymbol h_th~t?。
3. GRU使用方法
?同樣地,在TensorFlow中,也有Cell方式和層方式實現GRU網絡。GRUCell和GRU層的使用方法和之前的SimpleRNNCell、LSTMCell、SimpleRNN和LSTM非常類似。首先是GRUCell的使用,創建GRU Cell對象,并在時間軸上循環展開運算。例如:
import tensorflow as tf from tensorflow.keras import layersx = tf.random.normal([2, 80, 100]) xt = x[:, 0, :] # 得到一個時間戳的輸入 # 初始化狀態向量,GRU只有一個 h = [tf.zeros([2, 64])] cell = layers.GRUCell(64) # 新建GRU Cell,向量長度為64 # 在時間戳維度上解開,循環通過cell for xt in tf.unstack(x, axis=1):out, h = cell(xt, h) # 輸出形狀 print(out.shape)
運行結果如下所示:
?通過layers.GRU類可以方便創建一層GRU網絡層,通過Sequential容器可以堆疊多層GRU層的網絡。例如:
運行結果如下所示:
總結
以上是生活随笔為你收集整理的深度学习之循环神经网络(10)GRU简介的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习之循环神经网络(9)LSTM层使
- 下一篇: CSS中怎么设置Checkbox复选框控