论文阅读笔记《Meta-Learning with Memory-Augmented Neural Networks》
小樣本學習&元學習經典論文整理||持續更新
核心思想
??本文提出一種帶有記憶增強神經網絡(Memory-Augmented Neural Networks,MANN)的元學習算法用于解決小樣本學習問題。我們知道LSTM能夠通過遺忘門有選擇的保留部分先前樣本的信息(長期記憶),也可以通過輸入門獲得當前樣本的信息(短期記憶),這一記憶的方式是利用權重的更新隱式實現的。而在本文中,作者希望利用外部的內存空間顯式地記錄一些信息,使其結合神經網絡自身具備的長期記憶能力共同實現小樣本學習任務。
??如圖(a)所示,整個訓練過程分成多個Episode,每個Episode中包含若干個樣本 x x x和對應的標簽 y y y,將所有的樣本組合成一個序列, x t x_t xt?表示在 t t t時刻輸入的樣本, y t y_t yt?表示與之對應的標記,但要注意的是輸入時 x t x_t xt?和 y t y_t yt?并不是一一對應的,而是錯位對應,即 ( x t , y t ? 1 ) (x_t,y_{t-1}) (xt?,yt?1?)為 t t t時刻的輸入,這樣做的目的是讓網絡有目的的記住先前輸入的信息,因為只有保留有效的信息在下次再遇到同類樣本才能計算得到對應的損失。在每個Episode之間樣本序列都是被打亂的,這是為了避免網絡在訓練過程中慢慢記住了每個樣本對應的位置,這不是我們希望的。可以看到對于每個單元(圖中灰色的矩形塊)他的輸入信息既有前一個單元輸出的信息,又有當前輸入的信息,而輸出一方面要預測當前輸入樣本的類別,又要將信息傳遞給下個時刻的單元,這與LSTM或RNN很相似。在此基礎上作者增加了一個外部記憶模塊(如圖(b)中的藍色方框),他用來儲存在當前Eposide中所有"看過"的樣本的特征信息。怎樣去理解他呢?比如網絡第一次看到一張狗的照片,他并不能識別出它是什么,但是他把一些關鍵的特征信息記錄下來了,而且在下個時刻網絡得知了它的類別標簽是狗,此時網絡將特征信息與對應的標簽緊緊地聯系起來(Bind),當網絡下次看到狗的照片時,他用此時的特征信息與記憶模塊中儲存的特征信息進行匹配(Retrieve,真實的實現過程并不是匹配,而是通過回歸的方式獲取信息,此處只是方便大家理解),這樣就很容易知道這是一只狗了。這一過程其實與人類的學習模式非常接近了,但作者是如何利用神經網絡實現這一過程的呢?作者引入了神經圖靈機(NTM),為了方便下面的講解此處需要先介紹一下NTM。
??神經圖靈機的結構如上圖所示,它由控制器(Controller)和記憶模塊(Memory)構成,控制器利用寫頭(Write Heads)向記憶模塊中寫入信息,利用讀頭(Read Heads)從記憶模塊中讀取信息。回到本文的模型中,作者用一個LSTM或前向神經網絡作為控制器,用一個矩陣 M t M_t Mt?作為記憶模塊。給定一個輸入 x t x_t xt?,控制器輸出一個對應的鍵(Key) k t k_t kt?,可以理解為是一個特征向量,這個特征向量一方面要通過寫頭寫入記憶模塊,一方面又要通過讀頭匹配記憶模塊中的信息,用于完成分類任務或回歸任務。我們先介紹讀的過程,假設記憶模塊中已經儲存了許多的特征信息了,每個特征信息就是矩陣中的一行(特別注意,此處一行不是代表一個特征向量,而是某種抽象的特征。寫入的過程并不是將特征向量一行一行地堆放到記憶模塊中,寫入的過程遠比這個復雜),此時我們要計算當前特征向量 k t k_t kt?與記憶模塊 M t M_t Mt?中的各個向量之間的余弦距離 D ( k t , M t ( i ) ) D(k_t,M_t(i)) D(kt?,Mt?(i))(原文中用 K K K表示,為了避免與 k t k_t kt?混淆,特此改為 D D D),然后利用softmax函數將其轉化為讀取權重 w t r ( i ) w^r_t(i) wtr?(i),最后利用回歸的方式(加權求和)計算得到提取出來的記憶 r t = ∑ i w t r ( i ) M t ( i ) r_t=\sum_iw^r_t(i)M_t(i) rt?=∑i?wtr?(i)Mt?(i)。控制器一方面將 r t r_t rt?輸入到分類器(如softmax輸出層)中獲取當前樣本的類別,另一方面將其作為下一時刻控制器的一個輸入。
??寫的過程就是描述如何合理有效的將當前提取的特征信息存儲到記憶模塊中。作者采用了最少最近使用方法(Least Recently Used Access,LRUA),具體而言就是傾向于將特征信息存儲到使用次數較少的記憶矩陣位置,為了保護最近寫入的信息;或者寫入最近剛剛讀取過的記憶矩陣位置,因為相鄰兩個樣本之間可能存在一些相關信息。寫入的方法也是為記憶模塊中的每一行計算一個寫入權重 w t w ( i ) w^w_t(i) wtw?(i),然后將特征向量 k t k_t kt?乘以對應權重,在加上先前該位置保存的信息 M t ? 1 ( i ) M_{t-1}(i) Mt?1?(i)得到當前時刻的記憶矩陣 M t ( i ) = M t ? 1 ( i ) + w t w ( i ) k t M_t(i)=M_{t-1}(i)+w^w_t(i)k_t Mt?(i)=Mt?1?(i)+wtw?(i)kt?。而寫入權重 w t w w^w_t wtw?計算過程如下
其中 w t ? 1 r w^r_{t-1} wt?1r?表示上一時刻讀取權重,該值由讀的過程計算得到,權重越大表示上一時刻剛剛讀取過這一位置儲存的信息; σ ( ) \sigma() σ()表示sigmoid函數, α \alpha α表示一個門參數,用于控制兩個權重的比例。 w t ? 1 l u w^{lu}_{t-1} wt?1lu?表示上一時刻最少使用權重,其計算過程如下
其中, m ( w t u , n ) m(w_t^u,n) m(wtu?,n)表示向量 w t u w_t^u wtu?中第 n n n個最小的值, n n n表示內存讀取次數, w t u w_t^u wtu?表示使用權重,其計算過程如下
包含三個部分,上個時刻的使用權重 w t ? 1 u w_{t-1}^u wt?1u?, γ \gamma γ是衰減系數,讀取權重 w t r w^r_t wtr?和寫入權重 w t w w^w_t wtw?,當 w t u ( i ) w_t^u(i) wtu?(i)小于 m ( w t u , n ) m(w_t^u,n) m(wtu?,n)時表示位置 i i i是使用次數最少的位置之一,那么在下次寫入時,使用該位置的概率就更高。
??作者稱該模型是一種元學習算法,那是如何體現元學習過程的呢?我的理解是控制機本身是任務學習器(Learner),用于提取特征信息并預測分類,而整個模型則是一個元學習器(Meta-learner)用于學習如何將信息寫入/讀出記憶模塊。
實現過程
網絡結構&損失函數&訓練策略
??這部分內容論文中沒有特別具體的介紹,本身也不重要,核心在于整個模型的思想,具體的結構和損失函數可以結合任務需求自行選定。
網絡推廣
??該模型可以應用于分類和回歸任務。
創新點
- 設計了一種帶有記憶增強神經網絡的元學習算法,結合長期記憶和短期記憶兩方面優勢,能夠在看過某種類型的圖片一眼(one-shot),就能在下次遇到同類圖片時很快識別出來
- 利用神經圖靈機模型實現了記憶增強網絡,寫入的過程將特征信息與對應標簽緊密關聯起來,讀取的過程又將特征向量準確分類
算法評價
??該算法巧妙的將NTM應用于小樣本學習任務中,采用顯示的外部記憶模塊保留樣本特征信息,并利用元學習算法優化NTM的讀取和寫入過程,最終實現有效的小樣本分類和回歸。文中提到的長期記憶是通過控制器網絡權重參數的更新實現的,因為采用了錯位配對的方式,因此要到第二次見到該類別的圖像時才能得到相應的損失,并進行反向傳遞,因此權重更新過程是非常緩慢的,能夠保留很久之前的信息(如果權重更新速度很快,可能為了識別新的圖片類別,就迅速忘記之前識別過的圖片了)。短期記憶是由外部記憶模塊實現的,有人可能會覺得這個記憶模塊不是隨著訓練過程不斷儲存各個時刻的信息嗎?為什么叫做短期記憶呢?這是因為作者在兩個Eposide之間會清除記憶模塊,以避免兩個Eposide記憶之間相互干擾,而一個Eposide只是有若干個類別的少量樣本構成的,相對于整個學習過程他仍然屬于短期記憶。該算法整個思想都非常的新穎,NTM模型也十分的巧妙,作者自己也認為非常接近人類學習認知的模式了,但不知道是不是因為訓練比較困難的原因,該方法并沒有大規模的推廣。在學習該文章時,有必要提前了解一下NTM模型的原理,否則學習起來會比較困難。
如果大家對于深度學習與計算機視覺領域感興趣,希望獲得更多的知識分享與最新的論文解讀,歡迎關注我的個人公眾號“深視”。
總結
以上是生活随笔為你收集整理的论文阅读笔记《Meta-Learning with Memory-Augmented Neural Networks》的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 微信小程序自定义组件使用阿里矢量图标库图
- 下一篇: 在unbuntu16.04上安装网易云音