encoder decoder 模型理解
生活随笔
收集整理的這篇文章主要介紹了
encoder decoder 模型理解
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
encoder decoder 模型是比較難理解的,理解這個模型需要清楚lstm 的整個源碼細節,坦率的說這個模型我看了近十天,不敢說完全明白。
- 我把細胞的有絲分裂的圖片放在開頭,我的直覺細胞的有絲分裂和這個模型有相通之處
定義訓練編碼器
######################################################################################################################################################### # 定義訓練編碼器 #None表示可以處理任意長度的序列 # num_encoder_tokens表示特征的數目,三維張量的列數encoder_inputs = Input(shape=(None, num_encoder_tokens),name='encoder_inputs') # 編碼器,要求其返回狀態,lstm 公式理解https://blog.csdn.net/qq_38210185/article/details/79376053 encoder = LSTM(latent_dim, return_state=True,name='encoder_LSTM') # 編碼器的特征維的大小latent_dim,即單元數,也可以理解為lstm的層數#lstm 的輸出狀態,隱藏狀態,候選狀態 encoder_outputs, state_h, state_c = encoder(encoder_inputs) # 取出輸入生成的隱藏狀態和細胞狀態,作為解碼器的隱藏狀態和細胞狀態的初始化值。#上面兩行那種寫法很奇怪,看了幾天沒看懂,可以直接這樣寫 #encoder_outputs, state_h, state_c= LSTM(latent_dim, return_state=True,name='encoder_LSTM')(encoder_inputs)# 我們丟棄' encoder_output ',只保留隱藏狀態,候選狀態 encoder_states = [state_h, state_c] #########################################################################################################################################################定義訓練解碼器
######################################################################################################################################################### # 定義解碼器的輸入 # 同樣的,None表示可以處理任意長度的序列 # 設置解碼器,使用' encoder_states '作為初始狀態 # num_decoder_tokens表示解碼層嵌入長度,三維張量的列數 decoder_inputs = Input(shape=(None, num_decoder_tokens),name='decoder_inputs') # 接下來建立解碼器,解碼器將返回整個輸出序列 # 并且返回其中間狀態,中間狀態在訓練階段不會用到,但是在推理階段將是有用的 # 因解碼器用編碼器的隱藏狀態和細胞狀態,所以latent_dim必等 decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True,name='decoder_LSTM') # 將編碼器輸出的狀態作為初始解碼器的初始狀態 decoder_outputs, _, _ = decoder_lstm(decoder_inputs,initial_state=encoder_states)# 添加全連接層 # 這個full層在后面推斷中會被共享!! decoder_dense = Dense(num_decoder_tokens, activation='softmax',name='softmax') decoder_outputs = decoder_dense(decoder_outputs) #########################################################################################################################################################定義訓練模型
# Define the model that will turn # `encoder_input_data` & `decoder_input_data` into `decoder_target_data` #################################################################################################################### model = Model([encoder_inputs, decoder_inputs], decoder_outputs) #原始模型 #################################################################################################################### #model.load_weights('s2s.h5') # Run training model.compile(optimizer='rmsprop', loss='categorical_crossentropy') model.fit([encoder_input_data, decoder_input_data], decoder_target_data,batch_size=batch_size,epochs=epochs,validation_split=0.2) # 保存模型 model.save('s2s.h5')顯示模型的拓撲圖
import netron netron.start("s2s.h5")
編碼器,狹隘的說就是怎么將字符編碼,注意輸出是 encoder_states = [state_h, state_c] , 根據輸入序列得到隱藏狀態和候選門狀態,輸出是一個二元列表
# 定義推斷編碼器 根據輸入序列得到隱藏狀態和細胞狀態的路徑圖,得到模型,使用的輸入到輸出之間所有層的權重,與tf的預測簽名一樣 #################################################################################################################### encoder_model = Model(encoder_inputs, encoder_states) #編碼模型 ,注意輸出是 encoder_states = [state_h, state_c] #################################################################################################################### encoder_model.save('encoder_model.h5') import netron netron.start('encoder_model.h5')注意輸出是 encoder_states = [state_h, state_c] =[encoder_LSTM:1,: encoder_LSTM:2]
解碼模型
#解碼的隱藏層 decoder_state_input_h = Input(shape=(latent_dim,)) #解碼的候選門 decoder_state_input_c = Input(shape=(latent_dim,)) #解碼的輸入狀態 decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]decoder_outputsd, state_hd, state_cd = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)decoder_statesd = [state_hd, state_cd] decoder_outputsd1 = decoder_dense(decoder_outputsd) decoder_model = Model([decoder_inputs] + decoder_states_inputs,[decoder_outputsd] + decoder_statesd)解碼模型是一個三輸入,三輸出的模型
我們可以想一想為什么要用下面這段代碼
當輸入input_seq 之后,就得到了encoder_states ,在之后一直共享這個數值
然后就像解紐扣那樣,先找到第一個紐扣,就是’\t’,在target_text中’\t’就是第一個字符
【’\t’,states_value】–>下一個字符 c1
【‘c1’,states_value】–>下一個字符 c2
while循環一直到最后
代碼在git
總結
以上是生活随笔為你收集整理的encoder decoder 模型理解的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 使用pytorch动手实现LSTM模块
- 下一篇: 解决Ubuntu spyder 无法输入