Modeling with an Attention Mechanism: Normalizing Date Formats
Table of Contents
- 1. Overview
- 2. Data
- 3. Model
- 4. Training
- 5. Testing
Reference: 《基于深度學(xué)習(xí)的自然語言處理》 (Natural Language Processing with Deep Learning)
This article builds a model with an attention mechanism that converts dates in assorted formats into a standard format.
1. Overview
- LSTM and GRU alleviate the vanishing-gradient problem, but for long sentences with complex dependency structure, vanishing gradients still occur
- An attention mechanism can look at every position in the sentence at once and assign each position its own weight (attention), and these weights can be computed in parallel
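As a minimal sketch of this idea in pure NumPy (the shapes and the dot-product scoring are my own illustrative assumptions, not the article's model): attention scores every encoder position against a query, normalizes the scores with a softmax, and forms a context vector as the weighted sum. Each position is scored independently, so the scores can be computed in parallel.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
Tx, n_a = 30, 64                      # sequence length, encoder state size (assumed)
a = rng.standard_normal((Tx, n_a))    # encoder states, one per input position
query = rng.standard_normal(n_a)      # decoder state acting as the query

scores = a @ query                    # one score per position, computed in parallel
alphas = softmax(scores)              # attention weights, sum to 1
context = alphas @ a                  # weighted sum of encoder states -> (n_a,)
print(context.shape)                  # (64,)
```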
2. Data
- Generate date data: random human-readable formats (X) paired with the standard format (Y)
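The article's own generator is not shown; a minimal sketch using only the standard library (the helper name, format list, and date range are my own assumptions) could look like:

```python
import random
from datetime import date, timedelta

# Assumed human-readable formats; the real generator may use more
FORMATS = ['%d %B %Y', '%A %B %d %Y', '%B %d %Y', '%m/%d/%y', '%d %b %Y']

def random_date_pair(rng=random):
    # Pick a random day in roughly 1970-2020
    start = date(1970, 1, 1)
    d = start + timedelta(days=rng.randrange(0, 365 * 50))
    human = d.strftime(rng.choice(FORMATS)).lower()   # X: random human format
    machine = d.isoformat()                           # Y: standard YYYY-MM-DD
    return human, machine

dataset = [random_date_pair() for _ in range(5)]
for x, y in dataset:
    print(x, '==>', y)
```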
Output:
- Build the vocabulary and the mapping (char : idx)
- Convert each date (a char sequence) to an id sequence, padding / truncating to a fixed length
- Build one-hot matrices from the id sequences
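A sketch of these three steps (the helper names follow the common version of this exercise; the padding token and lengths are assumptions):

```python
import numpy as np

def build_vocab(texts, extra=('<pad>', '<unk>')):
    # Character-level vocabulary with special tokens first
    chars = sorted(set(''.join(texts)))
    return {c: i for i, c in enumerate(list(extra) + chars)}

def string_to_int(s, length, vocab):
    # Truncate, map chars to ids, then pad to a fixed length
    ids = [vocab.get(c, vocab['<unk>']) for c in s[:length]]
    ids += [vocab['<pad>']] * (length - len(ids))
    return ids

def to_one_hot(ids, vocab_size):
    return np.eye(vocab_size)[ids]            # (length, vocab_size)

texts = ['18 april 2014', '9/30/01']
vocab = build_vocab(texts)
X = np.array([string_to_int(t, 30, vocab) for t in texts])
Xoh = np.array([to_one_hot(row, len(vocab)) for row in X])
print(X.shape, Xoh.shape)                     # (2, 30) (2, 30, vocab_size)
```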
Check the dimensions of the generated one-hot matrices

Output:

(10000, 30) (10000, 10) (10000, 30, 37) (10000, 10, 11)

3. Model
- A softmax activation computes the attention weights
- Model components
- The model
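The attention step can be sketched in NumPy, mirroring the layer stack in the model summary (RepeatVector → Concatenate → Dense(10) → Dense(1) → softmax over time → Dot); the random weights and tanh activation here are illustrative assumptions, not trained parameters:

```python
import numpy as np

rng = np.random.default_rng(1)
Tx, n_a, n_s, n_att = 30, 64, 64, 10    # sizes matching the model summary

def softmax_over_time(x):
    # Normalize over the time axis (axis 0), not the feature axis
    e = np.exp(x - x.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def one_step_attention(a, s_prev, W1, b1, W2, b2):
    # RepeatVector: copy the decoder state across all Tx time steps
    s_rep = np.repeat(s_prev[None, :], a.shape[0], axis=0)   # (Tx, n_s)
    concat = np.concatenate([a, s_rep], axis=-1)             # (Tx, 128)
    e = np.tanh(concat @ W1 + b1)                            # Dense(10)
    energies = e @ W2 + b2                                   # Dense(1): (Tx, 1)
    alphas = softmax_over_time(energies)                     # attention weights
    context = (alphas * a).sum(axis=0)                       # Dot: (n_a,)
    return context, alphas

a = rng.standard_normal((Tx, n_a))      # bidirectional encoder outputs
s_prev = rng.standard_normal(n_s)       # previous decoder hidden state
W1 = rng.standard_normal((n_a + n_s, n_att)); b1 = np.zeros(n_att)
W2 = rng.standard_normal((n_att, 1));         b2 = np.zeros(1)
context, alphas = one_step_attention(a, s_prev, W1, b1, W2, b2)
print(context.shape, alphas.shape)      # (64,) (30, 1)
```

The decoder calls this once per output character, so the ten output steps reuse the same layers, which is why the summary below lists ten connections per layer.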
Output:
```
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape              Param #   Connected to
==================================================================================================
input_first (InputLayer)        [(None, 30, 37)]          0
s0 (InputLayer)                 [(None, 64)]              0
bidirectional (Bidirectional)   (None, 30, 64)            17920     input_first[0][0]
repeat_vector (RepeatVector)    (None, 30, 64)            0         s0[0][0], lstm[t][0]
concatenate (Concatenate)       (None, 30, 128)           0         bidirectional[0][0],
                                                                    repeat_vector[t][0]
dense (Dense)                   (None, 30, 10)            1290      concatenate[t][0]
dense_1 (Dense)                 (None, 30, 1)             11        dense[t][0]
attention_weights (Activation)  (None, 30, 1)             0         dense_1[t][0]
dot (Dot)                       (None, 1, 64)             0         attention_weights[t][0],
                                                                    bidirectional[0][0]
c0 (InputLayer)                 [(None, 64)]              0
lstm (LSTM)                     [(None, 64), (None, 64)]  33024     dot[t][0], s0[0][0]/lstm[t-1][0],
                                                                    c0[0][0]/lstm[t-1][2]
dense_2 (Dense)                 (None, 11)                715       lstm[t][0]
==================================================================================================
Total params: 52,960
Trainable params: 52,960
Non-trainable params: 0
__________________________________________________________________________________________________
```

(The shared decoder layers are each connected once per output step; those repeated connections are collapsed to `t = 0 .. 9` above for readability.)

4. Training
```python
from keras.optimizers import Adam

# Optimizer
opt = Adam(learning_rate=0.005, decay=0.01)

# Compile the model
model.compile(optimizer=opt,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Initialize the decoder states
s0 = np.zeros((m, n_s))
c0 = np.zeros((m, n_s))

# Yoh has shape (10000, 10, 11); swapping axes 0 and 1 gives (10, 10000, 11),
# so outputs is a list of length 10 whose elements are arrays of shape (10000, 11)
outputs = list(Yoh.swapaxes(0, 1))

history = model.fit([Xoh, s0, c0], outputs,
                    epochs=10, batch_size=128,
                    validation_split=0.1)
```

- Plot the loss and the per-position accuracy
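In legacy Keras optimizers, the `decay` argument shrinks the learning rate each update roughly as lr0 / (1 + decay * iterations); a quick sketch of that schedule (the helper is my own, the values match the config above):

```python
def decayed_lr(lr0, decay, iteration):
    # Legacy-Keras-style time-based decay: lr0 / (1 + decay * t)
    return lr0 / (1.0 + decay * iteration)

lr0, decay = 0.005, 0.01
schedule = [decayed_lr(lr0, decay, t) for t in range(0, 5000, 1000)]
print(schedule[0], schedule[-1])
```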
5. Testing
```python
s0 = np.zeros((1, n_s))
c0 = np.zeros((1, n_s))

test_data, _, _, _ = load_dateset(10)
for x, y in test_data:
    print(x + " ==> " + y)

for x, _ in test_data:
    source = string_to_int(x, Tx, human_vocab)
    source = np.array(list(map(
        lambda a: to_categorical(a, num_classes=len(human_vocab)), source)))
    source = source[np.newaxis, :]
    pred = model.predict([source, s0, c0])
    pred = np.argmax(pred, axis=-1)
    output = [inv_machine_vocab[int(i)] for i in pred]
    print('source:', x)
    print('output:', ''.join(output))
```

Output:
```
18 april 2014 ==> 2014-04-18
saturday august 22 1998 ==> 1998-08-22
october 22 1995 ==> 1995-10-22
thursday february 29 1996 ==> 1996-02-29
wednesday october 17 1979 ==> 1979-10-17
7 12 73 ==> 1973-12-07
9/30/01 ==> 2001-09-30
22 may 2001 ==> 2001-05-22
7 march 1979 ==> 1979-03-07
19 feb 2013 ==> 2013-02-19
```

Of the 10 predictions, 4 are wrong; the predicted date characters are not entirely correct:
```
source: 18 april 2014             output: 2014-04-18
source: saturday august 22 1998   output: 1998-08-22
source: october 22 1995           output: 1995-12-22  # wrong: should be month 10
source: thursday february 29 1996 output: 1996-02-29
source: wednesday october 17 1979 output: 1979-10-17
source: 7 12 73                   output: 1973-02-07  # wrong: should be month 12
source: 9/30/01                   output: 2001-05-00  # wrong: should be 09-30
source: 22 may 2001               output: 2011-05-22  # wrong: should be 2001
source: 7 march 1979              output: 1979-03-07
source: 19 feb 2013               output: 2013-02-19
```