对seq2seq的一些个人理解
對seq2seq的一些個人理解
原創?2017年05月10日 11:43:25因為做畢設用到seq2seq框架,網上關于seq2seq的資料很多,但關于seq2seq的代碼則比較少,閱讀tensorflow的源碼則需要跳來跳去比較麻煩(其實就是博主懶)。踩了很多坑后,形成了一些個人的理解,在這里記錄下,如果有人恰好路過,歡迎指出錯誤~
seq2seq圖解如下:?
上圖中,C是encoder輸出的最終狀態,作為decoder的初始狀態;W是encoder的最終輸出,作為decoder的初始輸入。
具體到tensorflow代碼中(tensorflow r1.1.0cpu版本),查閱tf.contrib.rnn.BasicLSTMCell的源碼如下:
class BasicLSTMCell(RNNCell):def __init__(self, num_units, forget_bias=1.0,input_size=None, state_is_tuple=True, activation=tanh,reuse=None):super(BasicLSTMCell, self).__init__(_reuse=reuse)if not state_is_tuple:logging.warn("%s: Using a concatenated state is slower and will soon be deprecated. Use state_is_tuple=True.", self)if input_size is not None:logging.warn("%s: The input_size parameter is deprecated.", self)self._num_units = num_unitsself._forget_bias = forget_biasself._state_is_tuple = state_is_tupleself._activation = activation@propertydef state_size(self):return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units)@propertydef output_size(self):return self._num_unitsdef call(self, inputs, state):"""Long short-term memory cell (LSTM)."""# Parameters of gates are concatenated into one multiply for efficiency.if self._state_is_tuple:c, h = stateelse:c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)concat = _linear([inputs, h], 4 * self._num_units, True)# i = input_gate, j = new_input, f = forget_gate, o = output_gatei, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))new_h = self._activation(new_c) * sigmoid(o)if self._state_is_tuple:new_state = LSTMStateTuple(new_c, new_h)else:new_state = array_ops.concat([new_c, new_h], 1)return new_h, new_state- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
令調用LSTM的命令為:
output,state = tf.contrib.rnn.BasicLSTMCell(input,init_state)- 1
可知,state其實是包含了output在內的。state[0]才是真正的state,即圖中的C;state[1]是output,即圖中的W。這樣一來,最后輸出的output其實就顯得雞肋了。(如果要在encode和decode之間搞事情的話,這點就比較重要了。博主就是踩了這個坑。。。當然如果不在這里搞事情的話就可以完美繞過這個坑)
知道這點后,那么接下來的就好理解多了。博主之前曾有過一段時間的疑惑,那就是seq2seq的decode_input到底是什么?如果跟target只是移了一個位,其他完全不變的話,那要encoder干什么?知道了上面的背景后,我們不難知道,教程中decode_input跟target的移位只是加速訓練過程。而在具體應用中,decode_input可以是encode的最后一個輸出,也可以自己設定一個全零的數組。個人覺得設定全零的數組比較好,因為初始狀態就已經包含了encode的最后一個輸出了,而且全零數組可以當作是一個開始的標識(至于seq2seq具體的訓練過程可視化,可以閱讀2017年ACL的一篇文章Visualizing and Understanding Neural Machine Translation?http://nlp.csai.tsinghua.edu.cn/~ly/papers/acl2017_dyz.pdf)
最后,還說幾點比較零散的:?
1、對于短句(<30詞),可以不進行輸入翻轉,模型收斂地稍微慢一點而已;對于長句則最好進行翻轉?
2、多閱讀教程,多實踐。上手操作永遠是學習的最佳途徑
總結
以上是生活随笔為你收集整理的对seq2seq的一些个人理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 谷歌开源 tf-seq2seq,你也能用
- 下一篇: tensorflow中的seq2seq例