详解seq2seq
1. 什么是seq2seq
在?然語?處理的很多應(yīng)?中,輸?和輸出都可以是不定?序列。以機(jī)器翻譯為例,輸?可以是?段不定?的英語?本序列,輸出可以是?段不定?的法語?本序列,例如:
英語輸?:“They”、“are”、“watching”、“.”
法語輸出:“Ils”、“regardent”、“.”
當(dāng)輸?和輸出都是不定?序列時(shí),我們可以使?編碼器—解碼器(encoder-decoder)或者seq2seq模型。序列到序列模型,簡稱seq2seq模型。這兩個(gè)模型本質(zhì)上都?到了兩個(gè)循環(huán)神經(jīng)?絡(luò),分別叫做編碼器和解碼器。編碼器?來分析輸?序列,解碼器?來?成輸出序列。兩 個(gè)循環(huán)神經(jīng)網(wǎng)絡(luò)是共同訓(xùn)練的。
下圖描述了使?編碼器—解碼器將上述英語句?翻譯成法語句?的?種?法。在訓(xùn)練數(shù)據(jù)集中,我們可以在每個(gè)句?后附上特殊符號(hào)“”(end of sequence)以表?序列的終?。編碼器每個(gè)時(shí)間步的輸?依次為英語句?中的單詞、標(biāo)點(diǎn)和特殊符號(hào)“”。下圖中使?了編碼器在 最終時(shí)間步的隱藏狀態(tài)作為輸?句?的表征或編碼信息。解碼器在各個(gè)時(shí)間步中使?輸?句?的 編碼信息和上個(gè)時(shí)間步的輸出以及隱藏狀態(tài)作為輸?。我們希望解碼器在各個(gè)時(shí)間步能正確依次 輸出翻譯后的法語單詞、標(biāo)點(diǎn)和特殊符號(hào)“”。需要注意的是,解碼器在最初時(shí)間步的輸? ?到了?個(gè)表?序列開始的特殊符號(hào)“”(beginning of sequence)。
2. 編碼器
編碼器的作?是把?個(gè)不定?的輸?序列變換成?個(gè)定?的背景變量 c,并在該背景變量中編碼輸?序列信息。常?的編碼器是循環(huán)神經(jīng)?絡(luò)。
讓我們考慮批量?小為1的時(shí)序數(shù)據(jù)樣本。假設(shè)輸?序列是 x1, . . . , xT,例如 xi 是輸?句?中的第 i 個(gè)詞。在時(shí)間步 t,循環(huán)神經(jīng)?絡(luò)將輸? xt 的特征向量 xt 和上個(gè)時(shí)間步的隱藏狀態(tài)ht?1ht?1變換為當(dāng)前時(shí)間步的隱藏狀態(tài)ht。我們可以?函數(shù) f 表達(dá)循環(huán)神經(jīng)?絡(luò)隱藏層的變換:
ht=f(xt,ht?1)ht=f(xt,ht?1)
接下來,編碼器通過?定義函數(shù) q 將各個(gè)時(shí)間步的隱藏狀態(tài)變換為背景變量:
c=q(h1,…,hT)c=q(h1,…,hT)
例如,當(dāng)選擇 q(***h*1, . . . , ***h***T ) = ***h***T 時(shí),背景變量是輸?序列最終時(shí)間步的隱藏狀態(tài)***h***T。
以上描述的編碼器是?個(gè)單向的循環(huán)神經(jīng)?絡(luò),每個(gè)時(shí)間步的隱藏狀態(tài)只取決于該時(shí)間步及之前的輸??序列。我們也可以使?雙向循環(huán)神經(jīng)?絡(luò)構(gòu)造編碼器。在這種情況下,編碼器每個(gè)時(shí)間步的隱藏狀態(tài)同時(shí)取決于該時(shí)間步之前和之后的?序列(包括當(dāng)前時(shí)間步的輸?),并編碼了整個(gè)序列的信息。
3. 解碼器
剛剛已經(jīng)介紹,編碼器輸出的背景變量 c 編碼了整個(gè)輸?序列 x1, . . . , xT 的信息。給定訓(xùn)練樣本中的輸出序列 y1, y2, . . . , yT′ ,對(duì)每個(gè)時(shí)間步 t′(符號(hào)與輸?序列或編碼器的時(shí)間步 t 有區(qū)別),解碼器輸出 yt′ 的條件概率將基于之前的輸出序列 y1,…,yt′?1y1,…,yt′?1 和背景變量 c,即:
P(yt′|y1,…,yt′?1,c)P(yt′|y1,…,yt′?1,c)
為此,我們可以使?另?個(gè)循環(huán)神經(jīng)?絡(luò)作為解碼器。在輸出序列的時(shí)間步 t′,解碼器將上?時(shí)間步的輸出 yt′?1yt′?1 以及背景變量 c 作為輸?,并將它們與上?時(shí)間步的隱藏狀態(tài) st′?1st′?1 變換為當(dāng)前時(shí)間步的隱藏狀態(tài)st′。因此,我們可以?函數(shù) g 表達(dá)解碼器隱藏層的變換:
st′=g(yt′?1,c,st′?1)st′=g(yt′?1,c,st′?1)
有了解碼器的隱藏狀態(tài)后,我們可以使??定義的輸出層和softmax運(yùn)算來計(jì)算P(yt′|y1,…,yt′?1,c)P(yt′|y1,…,yt′?1,c),例如,基于當(dāng)前時(shí)間步的解碼器隱藏狀態(tài) st′、上?時(shí)間步的輸出st′?1st′?1以及背景變量 c 來計(jì)算當(dāng)前時(shí)間步輸出 yt′ 的概率分布。
4. 訓(xùn)練模型
根據(jù)最?似然估計(jì),我們可以最?化輸出序列基于輸?序列的條件概率:
P(y1,…,yt′?1|x1,…,xT)=T′∏t′=1P(yt′|y1,…,yt′?1,x1,…,xT)P(y1,…,yt′?1|x1,…,xT)=∏t′=1T′P(yt′|y1,…,yt′?1,x1,…,xT)
=T′∏t′=1P(yt′|y1,…,yt′?1,c)=∏t′=1T′P(yt′|y1,…,yt′?1,c)
并得到該輸出序列的損失:
?logP(y1,…,yt′?1|x1,…,xT)=?T′∑t′=1logP(yt′|y1,…,yt′?1,c)?logP(y1,…,yt′?1|x1,…,xT)=?∑t′=1T′logP(yt′|y1,…,yt′?1,c)
在模型訓(xùn)練中,所有輸出序列損失的均值通常作為需要最小化的損失函數(shù)。在上圖所描述的模型預(yù)測中,我們需要將解碼器在上?個(gè)時(shí)間步的輸出作為當(dāng)前時(shí)間步的輸?。與此不同,在訓(xùn)練中我們也可以將標(biāo)簽序列(訓(xùn)練集的真實(shí)輸出序列)在上?個(gè)時(shí)間步的標(biāo)簽作為解碼器在當(dāng)前時(shí)間步的輸?。這叫作強(qiáng)制教學(xué)(teacher forcing)。
5. seq2seq模型預(yù)測
以上介紹了如何訓(xùn)練輸?和輸出均為不定?序列的編碼器—解碼器。本節(jié)我們介紹如何使?編碼器—解碼器來預(yù)測不定?的序列。
在準(zhǔn)備訓(xùn)練數(shù)據(jù)集時(shí),我們通常會(huì)在樣本的輸?序列和輸出序列后面分別附上?個(gè)特殊符號(hào)“”表?序列的終?。我們?cè)诮酉聛淼挠懻撝幸矊⒀?上?節(jié)的全部數(shù)學(xué)符號(hào)。為了便于討論,假設(shè)解碼器的輸出是?段?本序列。設(shè)輸出?本詞典Y(包含特殊符號(hào)“”)的?小為|Y|,輸出序列的最??度為T′。所有可能的輸出序列?共有 O(|y|T′)O(|y|T′) 種。這些輸出序列中所有特殊符號(hào)“”后?的?序列將被舍棄。
5.1 貪婪搜索
貪婪搜索(greedy search)。對(duì)于輸出序列任?時(shí)間步t′,我們從|Y|個(gè)詞中搜索出條件概率最?的詞:
yt′=argmaxy∈YP(y|y1,…,yt′?1,c)yt′=argmaxy∈YP(y|y1,…,yt′?1,c)
作為輸出。?旦搜索出“”符號(hào),或者輸出序列?度已經(jīng)達(dá)到了最??度T′,便完成輸出。我們?cè)诿枋鼋獯a器時(shí)提到,基于輸?序列?成輸出序列的條件概率是∏T′t′=1P(yt′|y1,…,yt′?1,c)∏t′=1T′P(yt′|y1,…,yt′?1,c)。我們將該條件概率最?的輸出序列稱為最優(yōu)輸出序列。而貪婪搜索的主要問題是不能保證得到最優(yōu)輸出序列。
下?來看?個(gè)例?。假設(shè)輸出詞典??有“A”“B”“C”和“”這4個(gè)詞。下圖中每個(gè)時(shí)間步
下的4個(gè)數(shù)字分別代表了該時(shí)間步?成“A”“B”“C”和“”這4個(gè)詞的條件概率。在每個(gè)時(shí)間步,貪婪搜索選取條件概率最?的詞。因此,圖10.9中將?成輸出序列“A”“B”“C”“”。該輸出序列的條件概率是0.5 × 0.4 × 0.4 × 0.6 = 0.048。
接下來,觀察下面演?的例?。與上圖中不同,在時(shí)間步2中選取了條件概率第??的詞“C”
。由于時(shí)間步3所基于的時(shí)間步1和2的輸出?序列由上圖中的“A”“B”變?yōu)榱讼聢D中的“A”“C”,下圖中時(shí)間步3?成各個(gè)詞的條件概率發(fā)?了變化。我們選取條件概率最?的詞“B”。此時(shí)時(shí)間步4所基于的前3個(gè)時(shí)間步的輸出?序列為“A”“C”“B”,與上圖中的“A”“B”“C”不同。因此,下圖中時(shí)間步4?成各個(gè)詞的條件概率也與上圖中的不同。我們發(fā)現(xiàn),此時(shí)的輸出序列“A”“C”“B”“”的條件概率是0.5 × 0.3 × 0.6 × 0.6 = 0.054,?于貪婪搜索得到的輸出序列的條件概率。因此,貪婪搜索得到的輸出序列“A”“B”“C”“”并?最優(yōu)輸出序列。
5.2 窮舉搜索
如果?標(biāo)是得到最優(yōu)輸出序列,我們可以考慮窮舉搜索(exhaustive search):窮舉所有可能的輸出序列,輸出條件概率最?的序列。
雖然窮舉搜索可以得到最優(yōu)輸出序列,但它的計(jì)算開銷 O(|y|T′)O(|y|T′) 很容易過?。例如,當(dāng)|Y| =
10000且T′ = 10時(shí),我們將評(píng)估 1000010=10401000010=1040 個(gè)序列:這?乎不可能完成。而貪婪搜索的計(jì)
算開銷是 O(|y|T′)O(|y|T′),通常顯著小于窮舉搜索的計(jì)算開銷。例如,當(dāng)|Y| = 10000且T′ = 10時(shí),我
們只需評(píng)估 10000?10=10510000?10=105 個(gè)序列。
5.3 束搜索
束搜索(beam search)是對(duì)貪婪搜索的?個(gè)改進(jìn)算法。它有?個(gè)束寬(beam size)超參數(shù)。我們將它設(shè)為 k。在時(shí)間步 1 時(shí),選取當(dāng)前時(shí)間步條件概率最?的 k 個(gè)詞,分別組成 k 個(gè)候選輸出序列的?詞。在之后的每個(gè)時(shí)間步,基于上個(gè)時(shí)間步的 k 個(gè)候選輸出序列,從 k |Y| 個(gè)可能的輸出序列中選取條件概率最?的 k 個(gè),作為該時(shí)間步的候選輸出序列。最終,我們從各個(gè)時(shí)間步的候選輸出序列中篩選出包含特殊符號(hào)“”的序列,并將它們中所有特殊符號(hào)“”后?的?序列舍棄,得到最終候選輸出序列的集合。
束寬為2,輸出序列最??度為3。候選輸出序列有A、C、AB、CE、ABD和CED。我們將根據(jù)這6個(gè)序列得出最終候選輸出序列的集合。在最終候選輸出序列的集合中,我們?nèi)∫韵路謹(jǐn)?shù)最?的序列作為輸出序列:
1LαlogP(y1,…,yL)=1LαT′∑t′=1logP(yt′|y1,…,yt′?1,c)1LαlogP(y1,…,yL)=1Lα∑t′=1T′logP(yt′|y1,…,yt′?1,c)
其中 L 為最終候選序列?度,α ?般可選為0.75。分?上的 Lα 是為了懲罰較?序列在以上分?jǐn)?shù)中較多的對(duì)數(shù)相加項(xiàng)。分析可知,束搜索的計(jì)算開銷為 O(k|y|T′)O(k|y|T′)。這介于貪婪搜索和窮舉搜索的計(jì)算開銷之間。此外,貪婪搜索可看作是束寬為 1 的束搜索。束搜索通過靈活的束寬 k 來權(quán)衡計(jì)算開銷和搜索質(zhì)量。
6. Bleu得分
評(píng)價(jià)機(jī)器翻譯結(jié)果通常使?BLEU(Bilingual Evaluation Understudy)(雙語評(píng)估替補(bǔ))。對(duì)于模型預(yù)測序列中任意的?序列,BLEU考察這個(gè)?序列是否出現(xiàn)在標(biāo)簽序列中。
具體來說,設(shè)詞數(shù)為 n 的?序列的精度為 pn。它是預(yù)測序列與標(biāo)簽序列匹配詞數(shù)為 n 的?序列的數(shù)量與預(yù)測序列中詞數(shù)為 n 的?序列的數(shù)量之?。舉個(gè)例?,假設(shè)標(biāo)簽序列為A、B、C、D、E、F,預(yù)測序列為A、B、B、C、D,那么:
P1=預(yù)測序列中的1元詞組在標(biāo)簽序列是否存在的個(gè)數(shù)預(yù)測序列1元詞組的個(gè)數(shù)之和P1=預(yù)測序列中的1元詞組在標(biāo)簽序列是否存在的個(gè)數(shù)預(yù)測序列1元詞組的個(gè)數(shù)之和
預(yù)測序列一元詞組:A/B/C/D,都在標(biāo)簽序列里存在,所以P1=4/5,以此類推,p2 = 3/4, p3 = 1/3, p4 = 0。設(shè) lenlabel和lenpredlenlabel和lenpred 分別為標(biāo)簽序列和預(yù)測序列的詞數(shù),那么,BLEU的定義為:
exp(min(0,1?lenlabellenpred))k∏n=1p12nnexp(min(0,1?lenlabellenpred))∏n=1kpn12n
其中 k 是我們希望匹配的?序列的最?詞數(shù)。可以看到當(dāng)預(yù)測序列和標(biāo)簽序列完全?致時(shí),
BLEU為1。
因?yàn)槠ヅ漭^??序列?匹配較短?序列更難,BLEU對(duì)匹配較??序列的精度賦予了更?權(quán)重。例如,當(dāng) pn 固定在0.5時(shí),隨著n的增?,0.512≈0.7,0.514≈0.84,0.518≈0.92,0.5116≈0.960.512≈0.7,0.514≈0.84,0.518≈0.92,0.5116≈0.96。另外,模型預(yù)測較短序列往往會(huì)得到較?pn 值。因此,上式中連乘項(xiàng)前?的系數(shù)是為了懲罰較短的輸出而設(shè)的。舉個(gè)例?,當(dāng)k = 2時(shí),假設(shè)標(biāo)簽序列為A、B、C、D、E、F,而預(yù)測序列為A、 B。雖然p1 = p2 = 1,但懲罰系數(shù)exp(1-6/2) ≈ 0.14,因此BLEU也接近0.14。
總結(jié)
- 上一篇: 从此以后谁也别说我不懂LDO了
- 下一篇: 5+App下Mui框架开发仿拼多多App