Beam Search的学习笔记(附代码实现)
引言
Beam Search 是一種受限的寬度優(yōu)先搜索方法,經(jīng)常用在各種 NLP 生成類任務(wù)中,例如機(jī)器翻譯、對(duì)話系統(tǒng)、文本摘要。本文首先介紹 Beam Search 的基本思想,然后再介紹一些beam search的優(yōu)化方法,最后附上自己的代碼實(shí)現(xiàn)。
1. Beam Search的基礎(chǔ)版本
在生成文本的時(shí)候,通常需要進(jìn)行解碼操作,貪心搜索 (Greedy Search) 是比較簡(jiǎn)單的解碼。Beam Search 對(duì)貪心搜索進(jìn)行了改進(jìn),擴(kuò)大了搜索空間,更容易得到全局最優(yōu)解。Beam Search 包含一個(gè)參數(shù) beam size k,表示每一時(shí)刻均保留得分最高的 k 個(gè)序列,然后下一時(shí)刻用這 k 個(gè)序列繼續(xù)生成。示意圖如下所示:
假設(shè)我們生成詞表中有三個(gè)單詞{我,愛(ài),你}。我們?cè)O(shè) K = 2 K=2 K=2。那么我們?cè)诘谝粫r(shí)刻確定兩個(gè)候選輸出是{我,你}。緊接著我們要考慮第二個(gè)輸出,具體步驟如下:
- 確定單詞“我”為第一時(shí)刻輸出,并將其作為第二時(shí)刻輸入,在已知 p ( x , 我 ) p(x,我) p(x,我)的情況下,各個(gè)單詞的輸出概率為3種情況,每個(gè)組合的概率為 P ( 我 ∣ x ) P ( y 2 ∣ x , 我 ) P(我|x)P(y_2|x,我) P(我∣x)P(y2?∣x,我)。
- 同樣我們把“你”也作為第二時(shí)刻輸入,同樣也有三種組合。
- 最后我們?cè)诹N組合中選擇概率最大的三個(gè)組合。
接下來(lái)要做的重復(fù)這個(gè)過(guò)程,逐步生成單詞,直到遇到結(jié)束標(biāo)識(shí)符停止。最后得到概率最大的那個(gè)生成序列。其概率為:
以上就是Beam search算法的思想,當(dāng)beam size=1時(shí),就變成了貪心算法。
2. Beam Search的優(yōu)化
Beam search算法也有許多改進(jìn)的地方。
2.1 Length normalization:懲罰短句
根據(jù)最后的概率公式可知,該算法傾向于選擇最短的句子,因?yàn)樵谶@個(gè)連乘操作中,每個(gè)因子都是小于1的數(shù),因子越多,最后的概率就越小。解決這個(gè)問(wèn)題的方式,最后的概率值除以這個(gè)生成序列的單詞數(shù),這樣比較的就是每個(gè)單詞的平均概率大小。此外,連乘因子較多時(shí),可能會(huì)超過(guò)浮點(diǎn)數(shù)的最小值,可以考慮取對(duì)數(shù)來(lái)緩解這個(gè)問(wèn)題。谷歌給的公式如下:
其中α∈[0,1],谷歌建議取值為[0.6,0.7]之間,α用于length normalization。
2.2 Coverage normalization:懲罰重復(fù)
另外我們?cè)谛蛄械叫蛄腥蝿?wù)中經(jīng)常會(huì)發(fā)現(xiàn)一個(gè)問(wèn)題,2016 年, 華為諾亞方舟實(shí)驗(yàn)室的論文提到,機(jī)器翻譯的時(shí)候會(huì)存在over translation or undertranslation due to attention coverage。 作者提出coverage-based atttention機(jī)制來(lái)解決coverage 問(wèn)題。 Google machine system 利用了如下的方式進(jìn)行了length normalization 和 coverage penalty。
還是上述公式,β用于控制coverage penalty
coverage penalty 主要用于使用 Attention 的場(chǎng)合,通過(guò) coverage penalty 可以讓 Decoder 均勻地關(guān)注于輸入序列 x x x 的每一個(gè) token,防止一些 token 獲得過(guò)多的 Attention。
2.3 End of sentence normalization:抑制長(zhǎng)句
有的時(shí)候我們發(fā)現(xiàn)生成的序列一直生成下去不會(huì)停止,有的時(shí)候我們可以顯式的設(shè)置最大生成長(zhǎng)度進(jìn)行控制,這里我們可以采用下式來(lái)進(jìn)行約束:
其中 ∣ X ∣ |X| ∣X∣是source的長(zhǎng)度, ∣ Y ∣ |Y| ∣Y∣是當(dāng)前target的長(zhǎng)度,那么由上式可知,target長(zhǎng)度越長(zhǎng)的話,上述得分越低,這樣就會(huì)防止出現(xiàn)生成一直不停止的情況。
3. Beam Search的代碼實(shí)現(xiàn)
總的來(lái)說(shuō),beam search不保證全局最優(yōu),但是比greedy search搜索空間更大,一般結(jié)果比greedy search要好。下面附上一些代碼實(shí)現(xiàn):
首先,首先定義一個(gè) Beam 類,作為一個(gè)存放候選序列的容器,屬性需維護(hù)當(dāng)前序列中的 token 以及對(duì)應(yīng)的對(duì)數(shù)概率,同時(shí)還需維護(hù)跟當(dāng)前 timestep 的 Decoder 相關(guān)的一些變量。此外,還需要給 Beam 類實(shí)現(xiàn)兩個(gè)函數(shù):一個(gè) extend 函數(shù)用以擴(kuò)展當(dāng)前的序列(即添加新的 time step的 token 及相關(guān)變量);一個(gè) score 函數(shù)用來(lái)計(jì)算當(dāng)前序列的分?jǐn)?shù)(在Beam類下的seq_score函數(shù)中有Length normalization以及Coverage normalization)。
class Beam(object):def __init__(self,tokens,log_probs,decoder_states,coverage_vector):self.tokens = tokensself.log_probs = log_probsself.decoder_states = decoder_statesself.coverage_vector = coverage_vectordef extend(self,token,log_prob,decoder_states,coverage_vector):return Beam(tokens=self.tokens + [token],log_probs=self.log_probs + [log_prob],decoder_states=decoder_states,coverage_vector=coverage_vector)def seq_score(self):"""This function calculate the score of the current sequence."""len_Y = len(self.tokens)# Lenth normalizationln = (5+len_Y)**config.alpha / (5+1)**config.alphacn = config.beta * torch.sum( # Coverage normalizationtorch.log(config.eps +torch.where(self.coverage_vector < 1.0,self.coverage_vector,torch.ones((1, self.coverage_vector.shape[1])).to(torch.device(config.DEVICE)))))score = sum(self.log_probs) / ln + cnreturn scoredef __lt__(self, other):return self.seq_score() < other.seq_score()def __le__(self, other):return self.seq_score() <= other.seq_score()接著我們需要實(shí)現(xiàn)一個(gè) best_k 函數(shù),作用是將一個(gè) Beam 容器中當(dāng)前 time step 的變量傳入 Decoder 中,計(jì)算出新一輪的詞表概率分布,并從中選出概率最大的 k 個(gè) token 來(lái)擴(kuò)展當(dāng)前序列(其中加入了End of sentence normalization),得到 k 個(gè)新的候選序列。
def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs):"""Get best k tokens to extend the current sequence at the current time step."""# use decoder to generate vocab distribution for the next tokenx_t = torch.tensor(beam.tokens[-1]).reshape(1, 1)x_t = x_t.to(self.DEVICE)# Get context vector from attention network.context_vector, attention_weights, coverage_vector = \self.model.attention(beam.decoder_states,encoder_output,x_padding_masks,beam.coverage_vector)# Replace the indexes of OOV words with the index of OOV token# to prevent index-out-of-bound error in the decoder.p_vocab, decoder_states, p_gen = \self.model.decoder(replace_oovs(x_t, self.vocab),beam.decoder_states,context_vector)final_dist = self.model.get_final_distribution(x,p_gen,p_vocab,attention_weights,torch.max(len_oovs))# Calculate log probabilities.log_probs = torch.log(final_dist.squeeze())# Filter forbidden tokens.# EOS token penalty. Follow the definition in# https://opennmt.net/OpenNMT/translation/beam_search/.log_probs[self.vocab.EOS] *= \config.gamma * x.size()[1] / len(beam.tokens)log_probs[self.vocab.UNK] = -float('inf')# Get top k tokens and the corresponding logprob.topk_probs, topk_idx = torch.topk(log_probs, k)# Extend the current hypo with top k tokens, resulting k new hypos.best_k = [beam.extend(x,log_probs[x],decoder_states,coverage_vector) for x in topk_idx.tolist()]return best_k最后我們實(shí)現(xiàn)主函數(shù) beam_search。初始化encoder、attention和decoder的輸?,然后對(duì)于每?個(gè)decodestep,對(duì)于現(xiàn)有的k個(gè)beam,我們分別利?best_k函數(shù)來(lái)得到各?最佳的k個(gè)extended beam,也就是每個(gè)decode step我們會(huì)得到k*k個(gè)新的beam,然后只保留分?jǐn)?shù)最?的k個(gè),作為下?輪需要擴(kuò)展的k個(gè)beam。為了只保留分?jǐn)?shù)最?的k個(gè)beam,我們可以??個(gè)堆(heap)來(lái)實(shí)現(xiàn),堆的中只保存k個(gè)節(jié)點(diǎn),根結(jié)點(diǎn)保存分?jǐn)?shù)最低的beam。
def beam_search(self,x,max_sum_len,beam_width,len_oovs,x_padding_masks):"""Using beam search to generate summary."""# run body_sequence input through encoderencoder_output, encoder_states = self.model.encoder(replace_oovs(x, self.vocab))coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE)# initialize decoder states with encoder forward statesdecoder_states = self.model.reduce_state(encoder_states)# initialize the hypothesis with a class Beam instance.init_beam = Beam([self.vocab.SOS],[0],decoder_states,coverage_vector)# get the beam size and create a list for stroing current candidates# and a list for completed hypothesisk = beam_widthcurr, completed = [init_beam], []# use beam search for max_sum_len (maximum length) stepsfor _ in range(max_sum_len):# get k best hypothesis when adding a new tokentopk = []for beam in curr:# When an EOS token is generated, add the hypo to the completed# list and decrease beam size.if beam.tokens[-1] == self.vocab.EOS:completed.append(beam)k -= 1continuefor can in self.best_k(beam,k,encoder_output,x_padding_masks,x,torch.max(len_oovs)):# Using topk as a heap to keep track of top k candidates.# Using the sequence scores of the hypos to campare# and object ids to break ties.add2heap(topk, (can.seq_score(), id(can), can), k)curr = [items[2] for items in topk]# stop when there are enough completed hypothesisif len(completed) == beam_width:break# When there are not engouh completed hypotheses,# take whatever when have in current best k as the final candidates.completed += curr# sort the hypothesis by normalized probability and choose the best oneresult = sorted(completed,key=lambda x: x.seq_score(),reverse=True)[0].tokensreturn result總結(jié)
以上是生活随笔為你收集整理的Beam Search的学习笔记(附代码实现)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: vue3使用xlsx 导出excel ,
- 下一篇: [计算机网络]应用层协议,HTTP,SM