Predicting Document Authorship with an RNN
Contents

- 1. Text preprocessing
- 2. Text serialization
- 3. Dataset split
- 4. Building the RNN model
- 5. Training
- 6. Testing

Reference: Natural Language Processing Based on Deep Learning (基于深度學習的自然語言處理)
1. Text preprocessing

Data preview

- Merge each author's documents into one, then remove `\n`, extra spaces, and the author's name (to prevent label leakage)
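A minimal sketch of this cleaning step. The function name `preprocessing` is taken from how it is called in the testing section; the file-reading logic, the author-name list, and the lower-casing are assumptions about what the cleaning might look like, not the article's exact code:

```python
import re

def preprocessing(path, author_names=("HAMILTON", "MADISON")):
    """Read one document, drop author names, and flatten all whitespace."""
    with open(path, encoding='utf-8') as f:
        text = f.read()
    for name in author_names:            # remove author names to avoid label leakage
        text = text.replace(name, '')
    text = re.sub(r'\s+', ' ', text)     # collapse \n and runs of spaces into one space
    return text.strip().lower()
```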
2. Text serialization

- Use a character-level tokenizer (`char_level=True`)
- Cut the id sequence into equal-length subsequences, each forming one sample
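To illustrate what `Tokenizer(char_level=True)` produces, here is a minimal pure-Python equivalent plus a `make_subsequence` helper. The helper name comes from the testing section below; its non-overlapping windowing and the 1-based character ids are assumptions (in Keras, id 0 is reserved for padding), and `SEQ_LEN = 30` is inferred from the `(None, 30, 128)` input shape in the model summary:

```python
from collections import Counter
import numpy as np

SEQ_LEN = 30  # subsequence length, matching the model's input_length

def fit_char_index(texts):
    """Map each character to an integer id (1-based, most frequent first),
    mimicking what Keras Tokenizer(char_level=True) builds in word_index."""
    counts = Counter(ch for t in texts for ch in t)
    return {ch: i + 1 for i, (ch, _) in enumerate(counts.most_common())}

def text_to_sequence(text, char_index):
    """Convert a text into a list of character ids."""
    return [char_index[ch] for ch in text if ch in char_index]

def make_subsequence(ids, label, seq_len=SEQ_LEN):
    """Cut an id sequence into non-overlapping fixed-length samples,
    all carrying the same author label."""
    n = len(ids) // seq_len
    X = np.array(ids[:n * seq_len]).reshape(n, seq_len)
    y = np.full(n, label)
    return X, y
```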
3. Dataset split

- Mix the samples from authors A and B
- Split them into a training set and a test set
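A sketch of the mix-and-split step using only numpy. The labels 0 for author A (Hamilton) and 1 for author B (Madison) are inferred from the voting code in the testing section; the toy stand-in arrays, the 80/20 ratio, and the fixed seed are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy stand-ins for the per-author arrays produced by make_subsequence;
# label 0 = author A (Hamilton), label 1 = author B (Madison)
X_A = rng.integers(1, 54, size=(100, 30)); y_A = np.zeros(100, dtype=int)
X_B = rng.integers(1, 54, size=(100, 30)); y_B = np.ones(100, dtype=int)

# Mix the two authors' samples, shuffle, then hold out 20% for testing
X = np.concatenate([X_A, X_B])
y = np.concatenate([y_A, y_B])
perm = rng.permutation(len(X))
X, y = X[perm], y[perm]

split = int(0.8 * len(X))
X_train, X_test = X[:split], X[split:]
y_train, y_test = y[:split], y[split:]
```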
4. Building the RNN model
```python
from keras.models import Sequential
from keras.layers import SimpleRNN, Dense, Embedding

Embedding_dim = 128   # dimension of the output embedding vectors
RNN_size = 256        # number of RNN units

model = Sequential()
model.add(Embedding(input_dim=len(char_tokenizer.word_index) + 1,
                    output_dim=Embedding_dim,
                    input_length=SEQ_LEN))
model.add(SimpleRNN(units=RNN_size, return_sequences=False))  # return only the last output in the output sequence
model.add(Dense(1, activation='sigmoid'))  # binary classification
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()
```

Model structure:
```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        (None, 30, 128)           6784
_________________________________________________________________
simple_rnn (SimpleRNN)       (None, 256)               98560
_________________________________________________________________
dense (Dense)                (None, 1)                 257
=================================================================
Total params: 105,601
Trainable params: 105,601
Non-trainable params: 0
_________________________________________________________________
```

If `return_sequences=True`, the output shapes of the last two layers become the following (a sequence-length dimension is added):
```
simple_rnn_1 (SimpleRNN)     (None, 30, 256)           98560
_________________________________________________________________
dense_1 (Dense)              (None, 30, 1)             257
```

5. Training
```python
batch_size = 4096   # number of samples per gradient-descent step
epochs = 20         # number of training epochs

history = model.fit(X_train, y_train,
                    batch_size=batch_size, epochs=epochs,
                    validation_data=(X_test, y_test),
                    verbose=1)
```

```
Epoch 1/20
88/88 [==============================] - 59s 669ms/step - loss: 0.6877 - accuracy: 0.5436 - val_loss: 0.6856 - val_accuracy: 0.5540
Epoch 2/20
88/88 [==============================] - 56s 634ms/step - loss: 0.6830 - accuracy: 0.5564 - val_loss: 0.6844 - val_accuracy: 0.5550
Epoch 3/20
88/88 [==============================] - 56s 633ms/step - loss: 0.6825 - accuracy: 0.5577 - val_loss: 0.6829 - val_accuracy: 0.5563
Epoch 4/20
88/88 [==============================] - 56s 634ms/step - loss: 0.6816 - accuracy: 0.5585 - val_loss: 0.6788 - val_accuracy: 0.5641
Epoch 5/20
88/88 [==============================] - 56s 637ms/step - loss: 0.6714 - accuracy: 0.5813 - val_loss: 0.6670 - val_accuracy: 0.5877
Epoch 6/20
88/88 [==============================] - 56s 637ms/step - loss: 0.6532 - accuracy: 0.6113 - val_loss: 0.6435 - val_accuracy: 0.6235
Epoch 7/20
88/88 [==============================] - 57s 648ms/step - loss: 0.6287 - accuracy: 0.6424 - val_loss: 0.6159 - val_accuracy: 0.6563
Epoch 8/20
88/88 [==============================] - 55s 620ms/step - loss: 0.5932 - accuracy: 0.6807 - val_loss: 0.5747 - val_accuracy: 0.6971
Epoch 9/20
88/88 [==============================] - 54s 615ms/step - loss: 0.5383 - accuracy: 0.7271 - val_loss: 0.5822 - val_accuracy: 0.7178
Epoch 10/20
88/88 [==============================] - 56s 632ms/step - loss: 0.4803 - accuracy: 0.7687 - val_loss: 0.4536 - val_accuracy: 0.7846
Epoch 11/20
88/88 [==============================] - 61s 690ms/step - loss: 0.3979 - accuracy: 0.8190 - val_loss: 0.3940 - val_accuracy: 0.8195
Epoch 12/20
88/88 [==============================] - 60s 687ms/step - loss: 0.3257 - accuracy: 0.8572 - val_loss: 0.3248 - val_accuracy: 0.8564
Epoch 13/20
88/88 [==============================] - 59s 668ms/step - loss: 0.2637 - accuracy: 0.8897 - val_loss: 0.2980 - val_accuracy: 0.8742
Epoch 14/20
88/88 [==============================] - 56s 638ms/step - loss: 0.2154 - accuracy: 0.9115 - val_loss: 0.2326 - val_accuracy: 0.9023
Epoch 15/20
88/88 [==============================] - 56s 639ms/step - loss: 0.1822 - accuracy: 0.9277 - val_loss: 0.2112 - val_accuracy: 0.9130
Epoch 16/20
88/88 [==============================] - 56s 640ms/step - loss: 0.1504 - accuracy: 0.9412 - val_loss: 0.1803 - val_accuracy: 0.9267
Epoch 17/20
88/88 [==============================] - 58s 660ms/step - loss: 0.1298 - accuracy: 0.9499 - val_loss: 0.1662 - val_accuracy: 0.9331
Epoch 18/20
88/88 [==============================] - 57s 643ms/step - loss: 0.1132 - accuracy: 0.9567 - val_loss: 0.1643 - val_accuracy: 0.9358
Epoch 19/20
88/88 [==============================] - 58s 659ms/step - loss: 0.1018 - accuracy: 0.9613 - val_loss: 0.1409 - val_accuracy: 0.9441
Epoch 20/20
88/88 [==============================] - 57s 642ms/step - loss: 0.0907 - accuracy: 0.9659 - val_loss: 0.1325 - val_accuracy: 0.9475
```

- Plot the training curves
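The "plot the training curves" step might look like the following matplotlib sketch; `plot_history` is a hypothetical helper (not from the article) that reads the loss/accuracy lists from the `History` object returned by `model.fit`:

```python
import matplotlib.pyplot as plt

def plot_history(history):
    """Plot training vs. validation loss and accuracy over epochs."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.plot(history.history['loss'], label='train loss')
    ax1.plot(history.history['val_loss'], label='val loss')
    ax1.set_xlabel('epoch')
    ax1.legend()
    ax2.plot(history.history['accuracy'], label='train acc')
    ax2.plot(history.history['val_accuracy'], label='val acc')
    ax2.set_xlabel('epoch')
    ax2.legend()
    plt.tight_layout()
    plt.show()
```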
6. Testing
```python
# Testing
for file in os.listdir('./papers/Unknown'):
    # preprocess the test document
    unk_file = preprocessing('./papers/Unknown/' + file)
    # convert the text to an id sequence
    unk_file_seq = char_tokenizer.texts_to_sequences([unk_file])[0]
    # cut out fixed-length subsequences to form multiple samples
    X_unk, _ = make_subsequence(unk_file_seq, UNKNOWN)
    # predict
    y_pred = model.predict(X_unk)
    y_pred = y_pred > 0.5
    votesA = np.sum(y_pred == 0)
    votesB = np.sum(y_pred == 1)
    print("Paper {} is predicted to be written by {}, votes {} : {}".format(
        file,
        "A:hamilton" if votesA > votesB else "B:madison",
        max(votesA, votesB),
        min(votesA, votesB)))
```

Output: the authors of all 5 documents were predicted correctly.
```
Paper paper_1.txt is predicted to be written by B:madison, votes 12211 : 8563
Paper paper_2.txt is predicted to be written by B:madison, votes 10899 : 8747
Paper paper_3.txt is predicted to be written by A:hamilton, votes 7041 : 6343
Paper paper_4.txt is predicted to be written by A:hamilton, votes 5063 : 4710
Paper paper_5.txt is predicted to be written by A:hamilton, votes 6878 : 4876
```