深度学习——RNN原理与TensorFlow2下的IMDB简单实践
在深度學習中,RNN是處理序列數據的有效方法之一,也是深度的一種很好的體現,本文將簡單介紹RNN的工作方式,以及針對IMDB數據集的簡單實踐
RNN簡介
RNN(Recurrent Neural Network),在基本的全連接層上迭代一層或多層帶有歷史信息(h)的RNN神經單元(RNN cell),使神經網絡能夠處理具有上下文關聯的序列數據,能夠有效減少隱層的參數量,提升訓練效率和準確率
為了更好的說明RNN的工作原理,我們帶入一個具體的目標,就是評價情感分析,如圖所示:
我們所要做的就是通過下方由單詞組成的評論來確定其情感是積極還是消極。我們把語句定義為x,輸出定義為y,輸出的結果即:P(y|x)
這里的embedding操作可以簡單理解為一個線性和,即
Oi=x@weighti+biasi
但這樣簡單的線性傳遞操作之后,只能通過每一個單詞的含義來判定情感,無法關聯到上下文,為了保存并處理上下文的語義,我們給線性操作附加一個歷史信息h。如果這樣處理,那我們完全可以省略掉針對每一個單詞不同的weight,而使用一個公共的weight用于單詞提取,稱為weightx,同理偏置稱為biasx,此時引入歷史信息h,初始化h0為全零,則公式修改為:
Oi=x@weightx+hi@weighth+biasx=hi+1
每一次計算的輸出和傳遞給下一層的歷史信息其實是相同的,這里分開來寫是為了下一篇LSTM留坑;而所謂的傳遞給下一層,實際上可以由同一個RNNcell迭代完成,這也是RNN名字的由來
說完了公式,我們回到神經網絡的根基,也就是梯度的求解
額外的參數定義:
t表示第t個句子,或者t時刻
激活函數——tan()
則:
ht=tan(x*weightx+ht-1@weighth)
yt=weighto*ht
這里我們忽略偏置
則損失函數的梯度由鏈式法則可以寫為:
第一個導數,由于損失函數和t時刻輸出yt是直接關聯的,因此第一個導數就是我們定義的損失函數對yt的直接求導,已知
第二個導數,當前時刻輸出yt對當前時刻歷史信息ht的導數在公式中可直接看出為weighto,已知
第三個導數,
令f=tanh(x),由ht公式可知
推導過程請自行演算
第四個導數,對tan激活函數求導后再對weighth 求導即可,已知
綜上可知,RNN梯度的復雜度需要對時間軸進行展開,復雜程度很高,因此需要用到TensorFlow等框架進行計算
IMDB數據集和RNN網絡的簡單實踐
對于數據集的加載可以直接使用TensorFlow2下的Keras中Dataset直接導入,如果下載速度很慢可能是因為……你懂得
total_words = 10000 (x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)接下來做數據預處理
max_review_len = 80 # x_train:[b, 80] # x_test: [b, 80] x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len) x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)訓練集和測試集構建
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True) db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) db_test = db_test.batch(batchsz, drop_remainder=True)簡單起見,我們只設計一層的RNN網絡,自定義一個RNN網絡用于訓練
class MyRNN(keras.Model):def __init__(self, units):super(MyRNN, self).__init__()# [b, 64]self.state = [tf.zeros([batchsz, units])]# self.state1 = [tf.zeros([batchsz, units])]# transform text to embedding representation# [b, 80] => [b, 80, 100]self.embedding = layers.Embedding(total_words, embedding_len,input_length=max_review_len)# [b, 80, 100] , h_dim: 64# RNN: cell1 ,cell2, cell3# SimpleRNNself.rnn_cell = layers.SimpleRNNCell(units, dropout=0.2)# self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)# fc, [b, 80, 100] => [b, 64] => [b, 1]self.fc= layers.Dense(1)def call(self, inputs, training=None):# [b, 80]x = inputs# embedding: [b, 80] => [b, 80, 100]x = self.embedding(x)# rnn cell compute# [b, 80, 100] => [b, 64]state = self.state# state1 = self.state1for word in tf.unstack(x, axis=1): # word: [b, 100]# h1 = x*wxh+h0*whh# out: [b, 64]out, state = self.rnn_cell(word, state, training)# out: [b, 64] => [b, 1]x = self.fc(out)# p(y is pos|x)prob = tf.sigmoid(x)return prob然后使用TensorFlow2中的compile and fit功能即可實現訓練和測試,給出筆者的運行結果
整體來看運行的正確率達到82%,沒有達到很高的原因在于層數太少,僅僅簡單實現了一層的RNN網絡,同時可以發現筆者使用了隨機種子,這樣的隨機RNN如果更換成更加貼合數據的因子就能夠有所提升
以上就是全部內容,筆者目前研究生在讀,所了解到的知識有限,歡迎大佬們留言一起交流學習
與50位技術專家面對面20年技術見證,附贈技術全景圖總結
以上是生活随笔為你收集整理的深度学习——RNN原理与TensorFlow2下的IMDB简单实践的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: c语言中闰年的流程图_C语言-算法与流程
- 下一篇: java的debug模式_java第六章