Chapter1-5_Speech_Recognition(Alignment of HMM, CTC and RNN-T)
文章目錄
- 1 為什么需要Alignment
- 2 窮舉所有的alignment
- 2.1 HMM的對齊
- 2.2 CTC的對齊
- 2.3 RNN-T的對齊
- 3 小結
本文為李弘毅老師【Speech Recognition - Alignment of HMM, CTC and RNN-T (optional)】的課程筆記,課程視頻youtube地址,點這里👈(需翻墻)。
下文中用到的圖片均來自于李宏毅老師的PPT,若有侵權,必定刪除。
文章索引:
上篇 - 1-4 HMM
下篇 - 1-6 RNN-T Training
總目錄
1 為什么需要Alignment
現在所有的seq2seq的模型forward的過程,從宏觀上來講,就是我們輸入一個序列XXX,可以輸出產生任意序列YYY的概率。
然后decode的時候,我們就是要找到一個序列YYY,使得P(Y∣X)P(Y|X)P(Y∣X)最大。在找這個序列的時候,一般不會窮舉,而是通過Beam Search去做。
Decoding:Y?=argmax?YlogP(Y∣X)Decoding:Y^*= \underbrace{argmax}_Y logP(Y|X) Decoding:Y?=Yargmax??logP(Y∣X)
像LAS這樣的的輸出中沒有額外的符號的模型,其結果就直接是P(Y∣X)P(Y|X)P(Y∣X)了。比如上圖要計算輸出序列ababab的概率就是
P(Y∣X)=P(a∣X)P(b∣a,X)P(<EOS>∣ab,X)P(Y|X)=P(a|X)P(b|a,X)P(<EOS>|ab,X) P(Y∣X)=P(a∣X)P(b∣a,X)P(<EOS>∣ab,X)
如果有點忘了LAS的decoder是長什么樣的話,可以看下面這幅圖。
在訓練的時候,我們就希望訓練出一組模型參數θ\thetaθ下,使得模型在decode的時候,得到標簽Y^\hat{Y}Y^的概率是最大的。
Training:argmax?θlogPθ(Y^∣X)Training: \underbrace{argmax}_{\theta}logP_{\theta}(\hat{Y}|X) Training:θargmax??logPθ?(Y^∣X)
以上的是模型輸出符號都是字典里的字符的情況,但是,當用CTC或者RNN-T這樣的模型時,我們的結果中是會出現?\phi?這樣的占位符的,那么就不能簡單地直接計算P(Y∣X)P(Y|X)P(Y∣X)了。而HMM這樣的模型,會需要去掉重復的字符,故也不能直接計算。
這個時候,我們需要計算的是所有能夠通過相應的對齊規則對齊到YYY的輸出序列hhh概率之和。
P(Y∣X)=∑h∈align(Y)P(h∣X)P(Y|X) = \sum_{h \in align(Y)}P(h|X) P(Y∣X)=h∈align(Y)∑?P(h∣X)
這就是我們要講alignment的原因。
下文會講到的如何窮舉所有可能的alignment。也就是上面公式中h∈align(Y)h \in align(Y)h∈align(Y)這個集合是怎么來的。
2 窮舉所有的alignment
為了方便說明,我們假設現在輸入的sequence長度為6,輸出的sequence為"cat"。由于HMM,CTC和RNN-T對齊的規則有所不同,故他們在找h∈align(Y)h \in align(Y)h∈align(Y)這個集合的時候,也會有些不同。
2.1 HMM的對齊
HMM的對齊規則為:
- 去掉所有的相鄰重復字符
所以,HMM在找h∈align(Y)h \in align(Y)h∈align(Y)的時候,就是在"cat"的基礎上,加入重復的字符,使得序列的長度等于T=6T=6T=6。寫成演算法的話,就是下圖中灰色方框里這樣。比如我們的目標是"cat",那么N=3N=3N=3,然后我們從"c"開始選擇重復一次或者多次,接著再去重復"a"和"t",我們需要保證所有的字符都至少出現一次,且它們出現的次數之和為輸入序列的長度TTT。
HMM要找的所有alignment都可以畫在一個表格當中。這個表格的起點為左上角的橘黃色的點,終點為右下角藍色的點。往右下方走,表示選擇下一個token,往正右方走,表示重復一個token。我們要在保證每次只能往右下或者正右的情況下,從橘點走到藍點。每一種走法的路徑,就是一個alignment。
2.2 CTC的對齊
CTC的對齊規則為:
- 首先合并所有的相鄰重復字符
- 然后去除掉所有的?\phi?
所以,CTC在找h∈align(Y)h \in align(Y)h∈align(Y)的時候,就是在"cat"的基礎上,加入重復的字符和?\phi?,使得序列的長度等于T=6T=6T=6。寫成演算法的話,就是下圖中灰色方框里這樣。比如我們的目標是"cat",那么N=3N=3N=3,然后我們從"c"或者“?\phi?”開始選擇重復一次或者多次,接著再去重復"a","?\phi?“和"t”,"?\phi?",我們需要保證所有的字符都至少出現一次,"?\phi?“可以出現也可不出現,且字符和”?\phi?"出現的次數之和為輸入序列的長度TTT。
CTC要找的所有alignment同樣也可以畫在一個表格當中。這個表格的起點為左上角的橘黃色的點,終點有兩個,為右下角藍色的點。
第一步,我們可以選擇字符或者“?\phi?”;如果選擇了字符"c",那么接下來可以有3種選擇,分別是往正右重復,往右下對角插入一個"?\phi?",往右下走馬步插入字符"a"。
如果我們選擇的是"?\phi?",那么我們就只有2種選擇,分別是往正右重復"?\phi?“或者往右下對角插入字符"c”。這個時候,是不能走右下馬步重復?\phi?的。
總結一下,就是在"?\phi?"行的時候,有正右或者右下對角2種選擇,在字符行的時候,有正右或者右下對角或者右下馬步3種選擇。
還有一種特殊情況需要注意的是,如果走右下角馬步得到的字符和當前字符是相同的時候,不同走右下角馬步。
基于以上的這些規則,從橘點走到右下腳兩個藍點中的任意一個所經過的路徑都是一個合理的alignment。
2.3 RNN-T的對齊
RNN-T的對齊規則為:
- 去除掉所有的?\phi?
所以,RNN-T在找h∈align(Y)h \in align(Y)h∈align(Y)的時候,就是在"cat"的基礎上,加入T=6T=6T=6個?\phi?。寫成演算法的話,就是下圖中灰色方框里這樣。我們在每個字符之間都可以插入數量不等的"?\phi?",但是末尾至少要有1個"?\phi?",然后所有"?\phi?“的個數之和為T=6T=6T=6。
RNN-T要找的所有alignment同樣也可以畫在一個表格當中,不過這個表格和之前的有所不同。這個表格的起點為左上角的藍色的點,終點為右下角藍色的點。每往正右走一步就是插入一個”?\phi?",每往正下走一步就是插入一個字符,直到走到右下角的藍點,所經過的路徑都是一個合理的alignment。
3 小結
HMM、CTC和RNN-T都可以用如下圖所示的HMM專用的狀態轉移圖來表示。其實也就是上文所述的東西,我覺得就算不看下面這個圖也無所謂,所以這里就不講了。
總結
以上是生活随笔為你收集整理的Chapter1-5_Speech_Recognition(Alignment of HMM, CTC and RNN-T)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 利用动态规划(DP)解决 Coin Ch
- 下一篇: LeetCode 1879. 两个数组最