小样本论文笔记5:Model Based - [6] One-shot learning with memory-augmented neural networks.
小樣本論文筆記5:Model Based - [6] One-shot learning with memory-augmented neural networks
文章目錄
- 小樣本論文筆記5:Model Based - [6] One-shot learning with memory-augmented neural networks
- 0. 前言
- 1. 要解決什么問(wèn)題
- 2. 用了什么方法
- 2.1 模型結(jié)構(gòu)
- 2.2 NTM&MANN
- 2.3 Least Recently Used Access 最少最近使用原則
- 3. 效果如何
- 3.1 數(shù)據(jù)集
- 3.2 Omniglot 分類(lèi)
- 3.3 回歸-擬合從未見(jiàn)過(guò)的函數(shù)
- 4. 還存在什么問(wèn)題&有什么可以借鑒
0. 前言
- 相關(guān)資料:
- 論文地址
- github
- 筆記參考
- 論文基本信息
- 領(lǐng)域:小樣本學(xué)習(xí)
- 作者單位:Google DeepMind
- 發(fā)表期刊和時(shí)間:ICML2016
- 一句話(huà)總結(jié)
- 提出了 Memory-Augmented Neural Network(MANN)結(jié)構(gòu),達(dá)到了,,,效果。
1. 要解決什么問(wèn)題
使用帶有記憶功能的神經(jīng)網(wǎng)絡(luò)Memory-Augmented Neural Network(MANN)元學(xué)習(xí)算法解決小樣本學(xué)習(xí)問(wèn)題。一個(gè)尺度可變的網(wǎng)絡(luò)需要符合兩點(diǎn)需求:
1、信息必須以特征方式存儲(chǔ)在內(nèi)存中,且具有穩(wěn)定和易以元素形式訪(fǎng)問(wèn)的性質(zhì);
2、參數(shù)數(shù)量不應(yīng)該與內(nèi)存大小綁定
而Neurak Turing Machines(NTM)和內(nèi)存網(wǎng)絡(luò)符合這種要求。因此,作者采用MANN進(jìn)行長(zhǎng)期和短期記憶的元學(xué)習(xí)任務(wù)。
2. 用了什么方法
2.1 模型結(jié)構(gòu)
- 如圖(a)所示,整個(gè)訓(xùn)練過(guò)程分成多個(gè)Episode,每個(gè)Episode中包含若干個(gè)樣本 x x x和對(duì)應(yīng)的標(biāo)簽 y y y,將所有的樣本組合成一個(gè)序列, x t x_t xt?表示在 t t t時(shí)刻輸入的樣本, y t y_t yt?表示與之對(duì)應(yīng)的標(biāo)記,但要注意的是輸入時(shí) x t x_t xt?和 y t y_t yt?并不是一一對(duì)應(yīng)的,而是錯(cuò)位對(duì)應(yīng),即 ( x t , y t ? 1 ) (x_t,y_{t-1}) (xt?,yt?1?)為tt時(shí)刻的輸入,這樣做的目的是讓網(wǎng)絡(luò)有目的的記住先前輸入的信息,因?yàn)橹挥斜A粲行У男畔⒃谙麓卧儆龅酵?lèi)樣本才能計(jì)算得到對(duì)應(yīng)的損失。在每個(gè)Episode之間樣本序列都是被打亂的,這是為了避免網(wǎng)絡(luò)在訓(xùn)練過(guò)程中慢慢記住了每個(gè)樣本對(duì)應(yīng)的位置,這不是我們希望的??梢钥吹綄?duì)于每個(gè)單元(圖中灰色的矩形塊)他的輸入信息既有前一個(gè)單元輸出的信息,又有當(dāng)前輸入的信息,而輸出一方面要預(yù)測(cè)當(dāng)前輸入樣本的類(lèi)別,又要將信息傳遞給下個(gè)時(shí)刻的單元,這與LSTM或RNN很相似。
- 在此基礎(chǔ)上作者增加了一個(gè)外部記憶模塊(如圖(b)中的藍(lán)色方框),他用來(lái)儲(chǔ)存在當(dāng)前Eposide中所有"看過(guò)"的樣本的特征信息。怎樣去理解他呢?比如網(wǎng)絡(luò)第一次看到一張狗的照片,他并不能識(shí)別出它是什么,但是他把一些關(guān)鍵的特征信息記錄下來(lái)了,而且在下個(gè)時(shí)刻網(wǎng)絡(luò)得知了它的類(lèi)別標(biāo)簽是狗,此時(shí)網(wǎng)絡(luò)將特征信息與對(duì)應(yīng)的標(biāo)簽緊緊地聯(lián)系起來(lái)(Bind),當(dāng)網(wǎng)絡(luò)下次看到狗的照片時(shí),他用此時(shí)的特征信息與記憶模塊中儲(chǔ)存的特征信息進(jìn)行匹配(Retrieve,真實(shí)的實(shí)現(xiàn)過(guò)程并不是匹配,而是通過(guò)回歸的方式獲取信息,此處只是方便大家理解),這樣就很容易知道這是一只狗了。這一過(guò)程其實(shí)與人類(lèi)的學(xué)習(xí)模式非常接近了,但作者是如何利用神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)這一過(guò)程的呢?作者引入了神經(jīng)圖靈機(jī)(NTM),為了方便下面的講解此處需要先介紹一下NTM。
2.2 NTM&MANN
-
神經(jīng)圖靈機(jī)的結(jié)構(gòu)如上圖所示,它由控制器(Controller)和記憶模塊(Memory)構(gòu)成,控制器利用寫(xiě)頭(Write Heads)向記憶模塊中寫(xiě)入信息,利用讀頭(Read Heads)從記憶模塊中讀取信息。
-
回到本文的模型中(如上圖),作者用一個(gè)LSTM或前向神經(jīng)網(wǎng)絡(luò)作為控制器,用一個(gè)矩陣 M t M_t Mt?作為記憶模塊。給定一個(gè)輸入 x t x_t xt?,控制器輸出一個(gè)對(duì)應(yīng)的鍵(Key) k t k_t kt?,可以理解為是一個(gè)特征向量,這個(gè)特征向量一方面要通過(guò)寫(xiě)頭寫(xiě)入記憶模塊,一方面又要通過(guò)讀頭匹配記憶模塊中的信息,用于完成分類(lèi)任務(wù)或回歸任務(wù)。我們先介紹讀的過(guò)程,假設(shè)記憶模塊中已經(jīng)儲(chǔ)存了許多的特征信息了,每個(gè)特征信息就是矩陣中的一行(特別注意,此處一行不是代表一個(gè)特征向量,而是某種抽象的特征。寫(xiě)入的過(guò)程并不是將特征向量一行一行地堆放到記憶模塊中,寫(xiě)入的過(guò)程遠(yuǎn)比這個(gè)復(fù)雜),此時(shí)我們要計(jì)算當(dāng)前特征向量 k t k_t kt?與記憶模塊 M t M_t Mt?中的各個(gè)向量之間的余弦距離 D ( k t , M t ( i ) ) D(k_t,M_t(i)) D(kt?,Mt?(i))(原文中用 K K K表示,為了避免與 k t k_t kt?混淆,特此改為 D D D),然后利用softmax函數(shù)將其轉(zhuǎn)化為讀取權(quán)重 w t r ( i ) w^r_t(i) wtr?(i),最后利用回歸的方式(加權(quán)求和)計(jì)算得到提取出來(lái)的記憶 r t = ∑ i w t r ( i ) M t ( i ) r_t=\sum_iw^r_t(i)M_t(i) rt?=∑i?wtr?(i)Mt?(i)。控制器一方面將 r t r_t rt?輸入到分類(lèi)器(如softmax輸出層)中獲取當(dāng)前樣本的類(lèi)別,另一方面將其作為下一時(shí)刻控制器的一個(gè)輸入。
2.3 Least Recently Used Access 最少最近使用原則
- 寫(xiě)的過(guò)程就是描述如何合理有效的將當(dāng)前提取的特征信息存儲(chǔ)到記憶模塊中。作者采用了最少最近使用方法(Least Recently Used Access,LRUA),具體而言就是傾向于將特征信息存儲(chǔ)到使用次數(shù)較少的記憶矩陣位置,為了保護(hù)最近寫(xiě)入的信息;或者寫(xiě)入最近剛剛讀取過(guò)的記憶矩陣位置,因?yàn)橄噜弮蓚€(gè)樣本之間可能存在一些相關(guān)信息。寫(xiě)入的方法也是為記憶模塊中的每一行計(jì)算一個(gè)寫(xiě)入權(quán)重 w t w ( i ) w^w_t(i) wtw?(i),然后將特征向量 k t k_t kt?乘以對(duì)應(yīng)權(quán)重,在加上先前該位置保存的信息 M t ? 1 ( i ) M_{t-1}(i) Mt?1?(i)得到當(dāng)前時(shí)刻的記憶矩陣 M t ( i ) = M t ? 1 ( i ) + w t w ( i ) k t M_t(i)=M_{t-1}(i)+w^w_t(i)k_t Mt?(i)=Mt?1?(i)+wtw?(i)kt?。而寫(xiě)入權(quán)重 w t w w^w_t wtw?計(jì)算過(guò)程如下:
w t w ← σ ( α ) w t ? 1 r + ( 1 ? σ ( α ) ) w t ? 1 l u \mathbf{w}_{t}^{w} \leftarrow \sigma(\alpha) \mathbf{w}_{t-1}^{r}+(1-\sigma(\alpha)) \mathbf{w}_{t-1}^{l u} wtw?←σ(α)wt?1r?+(1?σ(α))wt?1lu?
w t ? 1 r \mathbf{w}_{t-1}^{r} wt?1r? 表示上一時(shí)刻的讀取權(quán)重,該值由讀的過(guò)程計(jì)算得到,權(quán)重越大表示上一時(shí)刻剛剛讀取過(guò)這一位置儲(chǔ)存的信息; σ ( ) \sigma() σ()表示sigmoid函數(shù), α \alpha α表示一個(gè)門(mén)參數(shù),用于控制兩個(gè)權(quán)重的比例。 w t ? 1 l u w^{lu}_{t-1} wt?1lu?表示上一時(shí)刻最少使用權(quán)重,其計(jì)算過(guò)程如下:
w t l u ( i ) = { 0 if? w t u ( i ) > m ( w t u , n ) 1 if? w t u ( i ) ≤ m ( w t u , n ) w_{t}^{l u}(i)=\left\{\begin{array}{ll} 0 & \text { if } w_{t}^{u}(i)>m\left(\mathbf{w}_{t}^{u}, n\right) \\ 1 & \text { if } w_{t}^{u}(i) \leq m\left(\mathbf{w}_{t}^{u}, n\right) \end{array}\right. wtlu?(i)={01??if?wtu?(i)>m(wtu?,n)?if?wtu?(i)≤m(wtu?,n)?
其中, m ( w t u , n ) m(w_t^u,n) m(wtu?,n)表示向量 w t u w_t^u wtu?中第 n n n個(gè)最小的值, n n n表示內(nèi)存讀取次數(shù), w t u w_t^u wtu?表示使用權(quán)重,其計(jì)算過(guò)程如下
w t u ← γ w t ? 1 u + w t r + w t w \mathbf{w}_{t}^{u} \leftarrow \gamma \mathbf{w}_{t-1}^{u}+\mathbf{w}_{t}^{r}+\mathbf{w}_{t}^{w} wtu?←γwt?1u?+wtr?+wtw?
包含三個(gè)部分,上個(gè)時(shí)刻的使用權(quán)重 w t ? 1 u w_{t-1}^u wt?1u?, γ \gamma γ是衰減系數(shù),讀取權(quán)重 w t r w^r_t wtr?和寫(xiě)入權(quán)重 w t w w^w_t wtw?,當(dāng) w t u ( i ) w_t^u(i) wtu?(i)小于 m ( w t u , n ) m(w_t^u,n) m(wtu?,n)時(shí)表示位置 i i i是使用次數(shù)最少的位置之一,那么在下次寫(xiě)入時(shí),使用該位置的概率就更高。
根據(jù)寫(xiě)入權(quán)重更新記憶矩陣:
M t ( i ) ← M t ? 1 ( i ) + w t w ( i ) k t , ? i \mathbf{M}_{t}(i) \leftarrow \mathbf{M}_{t-1}(i)+w_{t}^{w}(i) \mathbf{k}_{t}, \forall i Mt?(i)←Mt?1?(i)+wtw?(i)kt?,?i
因此,內(nèi)存可以寫(xiě)入零內(nèi)存槽或先前使用的槽;如果是后者,那么最少使用的內(nèi)存就會(huì)被刪除。
3. 效果如何
3.1 數(shù)據(jù)集
- 數(shù)據(jù)集有兩種:
- 用于分類(lèi)任務(wù)和Omniglot;
- 從固定參數(shù)的高斯過(guò)程采樣函數(shù),用于回歸;
對(duì)Omniglot數(shù)據(jù)集進(jìn)行了旋轉(zhuǎn)和變換等數(shù)據(jù)增強(qiáng)操作,利用旋轉(zhuǎn)新增了一些類(lèi)別;最終1200類(lèi)用于訓(xùn)練,423類(lèi)用于測(cè)試。并且對(duì)圖片縮放到20*20.
3.2 Omniglot 分類(lèi)
-
在訓(xùn)練10萬(wàn)個(gè)Episode(包括5個(gè)隨機(jī)選取的類(lèi)別和5個(gè)隨機(jī)選取的標(biāo)簽???類(lèi)別和標(biāo)簽對(duì)應(yīng)嗎?),對(duì)模型用一系列的測(cè)試episode進(jìn)行測(cè)試。在該過(guò)程中,沒(méi)有進(jìn)行更多的訓(xùn)練,并且模型從未見(jiàn)過(guò)Omniglot測(cè)試集中的類(lèi)別。模型表現(xiàn)出了高分類(lèi)精度:在樣本第二次被輸入模型時(shí),就獲得了在每個(gè)episode上82.8%的準(zhǔn)確率,第五次出現(xiàn)時(shí),精度94.9%,第10次出現(xiàn)時(shí)則表現(xiàn)出98.1%的精度(所以,在測(cè)試過(guò)程中記憶矩陣也是會(huì)改變的?所以才能記住測(cè)試集中樣本的特征,從而在下次“看見(jiàn)”同類(lèi)樣本時(shí)才能“認(rèn)出來(lái)”嗎?)
與人類(lèi)表現(xiàn)相比,模型效果要優(yōu)于人類(lèi)對(duì)新樣本的識(shí)別能力。具體對(duì)人類(lèi)的測(cè)試過(guò)程如下。參與者的任務(wù)細(xì)節(jié):- 1、對(duì)于一張圖片,他們必須從1到5選擇一個(gè)合適的數(shù)字標(biāo)簽。
- 2、然后,那張圖片再次出現(xiàn),參與者需要根據(jù)圖片的類(lèi)別標(biāo)簽做一個(gè)不計(jì)時(shí)的預(yù)測(cè)。
- 3、然后,圖片消失,受試者被給予關(guān)于圖片標(biāo)簽正確性與否的反饋。
- 4、正確的標(biāo)簽會(huì)被展示無(wú)論預(yù)測(cè)結(jié)果準(zhǔn)確度高低,目的是讓受試者進(jìn)一步確認(rèn)正確的預(yù)測(cè)內(nèi)容。
- 5、在一個(gè)短期的2s延時(shí)之后,一個(gè)新的圖片會(huì)出現(xiàn),然后重復(fù)這個(gè)過(guò)程。
并且,有趣的是,即使是MANN第一次碰到新樣本,它的分類(lèi)效果也好于“隨機(jī)瞎猜”。因?yàn)?#xff0c;它會(huì)根據(jù)“記憶”排斥一些確定的錯(cuò)誤選項(xiàng)。這個(gè)過(guò)程也跟人類(lèi)參與者的思考過(guò)程類(lèi)似。
-
表2是分別使用5類(lèi)和15類(lèi)分類(lèi)時(shí),MANN與其他結(jié)構(gòu)對(duì)比的分類(lèi)效果。KNN的效果其實(shí)還不錯(cuò)。無(wú)參,并且對(duì)記憶空間無(wú)限制。但MANN使用LRUA效果還是有效的,并且隨著次數(shù)的增多,精度逐步上升。
3.3 回歸-擬合從未見(jiàn)過(guò)的函數(shù)
- x值固定,是數(shù)據(jù)樣本,y值是函數(shù)值;(a)是MANN經(jīng)過(guò)20個(gè)樣本之后對(duì)x的預(yù)測(cè)值;(b)是GP產(chǎn)生的函數(shù)值。
4. 還存在什么問(wèn)題&有什么可以借鑒
- 該算法巧妙的將NTM應(yīng)用于小樣本學(xué)習(xí)任務(wù)中,采用顯示的外部記憶模塊保留樣本特征信息,并利用元學(xué)習(xí)算法優(yōu)化NTM的讀取和寫(xiě)入過(guò)程,最終實(shí)現(xiàn)有效的小樣本分類(lèi)和回歸。文中提到的長(zhǎng)期記憶是通過(guò)控制器網(wǎng)絡(luò)權(quán)重參數(shù)的更新實(shí)現(xiàn)的,因?yàn)椴捎昧隋e(cuò)位配對(duì)的方式,因此要到第二次見(jiàn)到該類(lèi)別的圖像時(shí)才能得到相應(yīng)的損失,并進(jìn)行反向傳遞,因此權(quán)重更新過(guò)程是非常緩慢的,能夠保留很久之前的信息(如果權(quán)重更新速度很快,可能為了識(shí)別新的圖片類(lèi)別,就迅速忘記之前識(shí)別過(guò)的圖片了)。短期記憶是由外部記憶模塊實(shí)現(xiàn)的,有人可能會(huì)覺(jué)得這個(gè)記憶模塊不是隨著訓(xùn)練過(guò)程不斷儲(chǔ)存各個(gè)時(shí)刻的信息嗎?為什么叫做短期記憶呢?這是因?yàn)樽髡咴趦蓚€(gè)Eposide之間會(huì)清除記憶模塊,以避免兩個(gè)Eposide記憶之間相互干擾,而一個(gè)Eposide只是有若干個(gè)類(lèi)別的少量樣本構(gòu)成的,相對(duì)于整個(gè)學(xué)習(xí)過(guò)程他仍然屬于短期記憶。該算法整個(gè)思想都非常的新穎,NTM模型也十分的巧妙,作者自己也認(rèn)為非常接近人類(lèi)學(xué)習(xí)認(rèn)知的模式了,**但不知道是不是因?yàn)橛?xùn)練比較困難的原因,該方法并沒(méi)有大規(guī)模的推廣。**在學(xué)習(xí)該文章時(shí),有必要提前了解一下NTM模型的原理,否則學(xué)習(xí)起來(lái)會(huì)比較困難。
總結(jié)
以上是生活随笔為你收集整理的小样本论文笔记5:Model Based - [6] One-shot learning with memory-augmented neural networks.的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: python开发部署时新增数据库中表的方
- 下一篇: RPG Maker