Link to the original article; thanks to the original author.
Table of Contents

I. The basic workflow of text generation and translation
    Training and decoding of translation models
        Training
        Decoding
    Training and decoding of generative models (the GPT family)
        Training
        Decoding
II. Decoding strategies
    1. Greedy search
    2. Beam search
    3. Random sampling
    4. Top-K sampling and top-p (nucleus) sampling
        Top-K Sampling
        Top-p (nucleus) sampling
III. Decoding with the transformers library
The quality of text generation and translation is not only a matter of how good the model is; the decoding strategy used at prediction time matters just as much, and different strategies produce different results. After years of research, the main decoding strategies for text generation that I know of are greedy search, beam search, random sampling, top-k sampling, and top-p sampling. This post walks through each of these decoding algorithms.
I. The basic workflow of text generation and translation

Training and decoding of translation models

Training

A translation task pairs one src input with one tgt input; in general, the src and tgt lengths differ. A simple flow chart is shown below:
The model produces an output as long as the target: the output [T,B,D] passes through a fully connected classification layer to give logits [T,B,V], a distribution over the vocabulary for every position; the loss is then computed against the tgt input ids [T,B].
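As a hedged sketch of those shapes (the sizes and names here are made up for illustration, and padding is ignored for brevity), the loss computation looks roughly like this:

import torch
import torch.nn as nn

T, B, D, V = 32, 16, 512, 6000            # illustrative: time, batch, hidden, vocab
output = torch.randn(T, B, D)             # decoder output [T, B, D]
proj = nn.Linear(D, V)                    # the fully connected classification layer
logits = proj(output)                     # [T, B, V]
tgt = torch.randint(0, V, (T, B))         # target token ids [T, B]
loss = nn.CrossEntropyLoss()(logits.reshape(-1, V), tgt.reshape(-1))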
Decoding

As shown in the figure below, once the model is trained, decoding starts from the src embedding plus a special start token such as <cls> on the tgt side. Decoding produces the first token, that token is appended to the tgt-side input, decoding continues to produce the second token, and so on. Every decoding step runs a full model forward pass, which is why decoding is slow; the loop stops when the end token appears or the maximum length is reached.
Training and decoding of generative models (the GPT family)

Training

A GPT model is trained by feeding in a piece of natural text directly, producing its embedding and then, through a classifier head, the logits [B,L,V]; the input text itself serves as the label (shifted by one position, so that position t predicts token t+1), and the cross-entropy loss is computed. The shapes along the way are: input_ids [B,L] -> embedding [B,L,D] -> logits [B,L,V].
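A minimal sketch of that loss with illustrative shapes (the one-position shift is the standard causal-LM convention; the logits tensor stands in for the model output):

import torch
import torch.nn.functional as F

B, L, V = 4, 128, 50257                      # illustrative batch, length, vocab sizes
input_ids = torch.randint(0, V, (B, L))      # [B, L]
logits = torch.randn(B, L, V)                # stand-in for the model's output [B, L, V]
# position t predicts token t+1, so shift logits and labels by one
loss = F.cross_entropy(logits[:, :-1, :].reshape(-1, V),
                       input_ids[:, 1:].reshape(-1))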
Decoding

Much as above: the token decoded at the current step is concatenated with the previous tokens, and the result is fed back in to decode the next token.
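A minimal greedy sketch of that loop. The toy stand-in model scores each position independently, which is not a real causal LM, but it is enough to show the loop structure; sos_id, eos_id, and max_len are illustrative:

import torch
import torch.nn as nn

V, D = 100, 32
model = nn.Sequential(nn.Embedding(V, D), nn.Linear(D, V))  # toy stand-in for a causal LM
sos_id, eos_id, max_len = 1, 2, 16
ids = torch.tensor([[sos_id]])                # start from the special start token
for _ in range(max_len):
    logits = model(ids)                       # [1, L, V]
    next_id = logits[0, -1].argmax()          # next token from the last position
    ids = torch.cat([ids, next_id.view(1, 1)], dim=1)
    if next_id.item() == eos_id:
        break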
II. Decoding strategies

The diagrams above gave a quick picture of how generation models are trained and decoded, and how the tensor shapes change along the way. Besides the model itself, the quality of the final decoded result also depends heavily on which decoding strategy is used.
1. Greedy search

At prediction time, after the fully connected layer we have a probability distribution for the whole sequence of shape [(B*S), vocab_size], that is, a distribution over the vocabulary for each of the B*S positions. How do we turn this distribution into the most plausible sequence? The most direct approach is to take the highest-probability token at every position, until the sequence completes or the end token [SEP] appears. A simple implementation:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def gen_nopeek_mask(length):
    """Returns the no-peek (causal) mask for `length` target tokens,
    e.g. [[0., -inf, -inf], [0., 0., -inf], [0., 0., 0.]]"""
    # transpose so every position may only attend to earlier positions
    mask = torch.triu(torch.ones(length, length)).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def greedy_search_decode(model, src, src_key_padding_mask, max_len: int = 64, start_symbol: int = 1):
    """
    :param model: Transformer model
    :param src: the encoder input
    :param max_len: maximum sequence length
    :return: ys, the decoded sequence
    None of these masks can be omitted during decoding.
    """
    src_mask = gen_nopeek_mask(src.shape[1]).to(device)
    memory_key_padding_mask = src_key_padding_mask
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        tar_mask = gen_nopeek_mask(ys.shape[1]).to(device)
        out = model.forward(src, ys,
                            src_key_padding_mask=src_key_padding_mask,
                            tgt_key_padding_mask=None,
                            src_mask=src_mask, tar_mask=tar_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        out = out[:, -1, :]                    # logits of the last position
        _, next_word = torch.max(out, dim=1)   # greedy: take the argmax token
        next_word = next_word.data[0]
        if next_word != 2:                     # 2 is the end-of-sequence id
            ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        else:
            break
    return ys

The drawback of this implementation is that it cannot decode a batch (batch size > 1) in parallel. It can be adapted for parallel processing: after each decoding step, check whether every row in the batch has produced the end token. The check is:
(ys == 2).sum(1).bool().all()
which tests whether the element 2 (the end symbol) has appeared in every row of ys.
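For example, with 2 as the end id:

import torch

ys = torch.tensor([[1, 5, 2, 0],
                   [1, 7, 9, 2]])
print((ys == 2).sum(1).bool().all())   # tensor(True): both rows contain the end id 2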
The full batched decoding code follows:
def greedy_search_decode(model, src, src_key_padding_mask, max_len: int = 64, start_symbol: int = 1, bs: int = 32):
    """
    :param model: Transformer model
    :param src: the encoder input
    :param max_len: maximum sequence length
    :return: ys, the decoded sequences
    None of these masks can be omitted during decoding.
    """
    src_mask = gen_nopeek_mask(src.shape[1]).to(device)
    memory_key_padding_mask = src_key_padding_mask
    ys = torch.ones(bs, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len - 1):
        tar_mask = gen_nopeek_mask(ys.shape[1]).to(device)
        out = model.forward(src, ys,
                            src_key_padding_mask=src_key_padding_mask,
                            tgt_key_padding_mask=None,
                            src_mask=src_mask, tar_mask=tar_mask,
                            memory_key_padding_mask=memory_key_padding_mask)
        out = out[:, -1, :]                      # [bs, vocab] logits of the last position
        _, next_word = torch.max(out, dim=1)     # [bs] greedy tokens
        ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        if (ys == 2).sum(1).bool().all():        # stop once every row has emitted EOS (id 2)
            break
    return ys

Here is a worked decoding example.
"the nice woman" is the locally best choice at each time step, with probability 0.5 * 0.4 = 0.2; but as the figure shows, it is not the most probable sequence overall: "the dog has" has the highest whole-sentence probability, 0.4 * 0.9 = 0.36. The obvious weakness of greedy search is that the sequence it produces does not necessarily have the highest overall probability; it can easily discard a token with a slightly lower current probability that leads to a much higher-probability continuation. To avoid this, beam search was proposed.
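A tiny sketch of this failure mode. The two probabilities quoted above (0.5 for "nice", 0.4 for "dog", then 0.4 for "woman" and 0.9 for "has") follow the figure; the remaining values are made-up fillers that complete the toy distribution:

# next-word probabilities after "the", and conditionals after each choice
step1 = {'nice': 0.5, 'dog': 0.4, 'car': 0.1}
step2 = {'nice': {'woman': 0.4, 'house': 0.3, 'guy': 0.3},
         'dog':  {'has': 0.9, 'runs': 0.05, 'and': 0.05}}

w1 = max(step1, key=step1.get)                      # greedy picks 'nice' (0.5)
w2 = max(step2[w1], key=step2[w1].get)              # then 'woman' (0.4)
print('greedy:', ('the', w1, w2), step1[w1] * step2[w1][w2])   # 0.2

# exhaustive search over both expanded branches finds the true maximum
best = max(((a, b, step1[a] * p) for a, d in step2.items() for b, p in d.items()),
           key=lambda t: t[2])
print('best  :', ('the', best[0], best[1]), best[2])           # ('dog', 'has'), 0.36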
2. Beam search

To avoid greedy search discarding high-probability continuations, beam search keeps the beam_num best partial results at every step. Each of the current beam_num sequences is fed into the model, each produces V new candidate extensions, giving beam_num * V candidates in total; these are sorted and the best beam_num are kept. The process repeats until decoding finishes, and the sequence with the highest product of probabilities is selected from the final beam_num candidates. In short: keep the top beam_num partial results at every decoding step, and only at the end pick the most probable sequence.
Here is an example with beam_num = 2 (figure from 全面了解Beam Search):
In the first decoding step we pick the two most probable tokens [A, C]; feeding each into the second step yields the 10 candidates [AA, AB, AC, AD, AE, CA, CB, CC, CD, CE], of which only the best two, [AB, CE], are kept; these are fed into the third step, and so on, until the sequence with the highest overall probability is obtained.
With bs = 1, implementing beam search is fairly simple: modify the greedy search code to track the current best beam_num sequences and their scores, and at every step sort the beam_num * V candidates to obtain the new beam_num sequences.
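A minimal bs = 1 sketch along those lines. beam_search_bs1 and step_logits_fn are hypothetical names; step_logits_fn(seq) stands in for one model forward pass (as in the greedy code above) returning the next-token logits [V] for a token-id list seq:

import torch
import torch.nn.functional as F

def beam_search_bs1(step_logits_fn, sos_id=1, eos_id=2, max_len=20, beam_num=3):
    beams = [([sos_id], 0.0)]                      # (sequence, sum of log-probs)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            log_probs = F.log_softmax(step_logits_fn(seq), dim=-1)
            top_lp, top_id = torch.topk(log_probs, beam_num)
            for lp, tok in zip(top_lp.tolist(), top_id.tolist()):
                candidates.append((seq + [tok], score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_num]:   # keep the beam_num best expansions
            (finished if seq[-1] == eos_id else beams).append((seq, score))
        if not beams:
            break
    finished.extend(beams)                          # unfinished beams still compete
    return max(finished, key=lambda c: c[1])

# e.g. with random logits over a 100-token vocabulary:
# beam_search_bs1(lambda seq: torch.randn(100))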
When bs > 1, implementing an efficient beam search is considerably more fiddly. Drawing on 全面了解Beam Search and the source code of Hugging Face's transformers, the de facto standard NLP library, I adapted the beam search code as follows:
import torch
import torch.nn.functional as F
from einops import rearrange

"""Fully batched beam search (batch_size = n)."""


class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty):
        self.max_length = max_length - 1
        self.length_penalty = length_penalty
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        return len(self.beams)

    def add(self, hyp, sum_logprobs):
        # length-normalized score of the finished hypothesis
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        if len(self) < self.num_beams:
            return False
        cur_score = best_sum_logprobs / cur_len ** self.length_penalty
        ret = self.worst_score >= cur_score
        return ret


def gen_nopeek_mask(length):
    """Returns the no-peek (causal) mask,
    e.g. [[0., -inf, -inf], [0., 0., -inf], [0., 0., 0.]]"""
    mask = rearrange(torch.triu(torch.ones(length, length)) == 1, 'h w -> w h')
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def beam_sizing(num_beams, src, src_key_padding_mask):
    # repeat every source row num_beams times so each beam has its own copy
    temp1 = src
    temp2 = src_key_padding_mask
    for i in range(num_beams - 1):
        temp1 = torch.cat([temp1, src], dim=0)
        temp2 = torch.cat([temp2, src_key_padding_mask], dim=0)
    index = 0
    for i in range(src.shape[0]):
        for _ in range(num_beams):
            temp1[index, ...] = src[i, ...]
            temp2[index, ...] = src_key_padding_mask[i, ...]
            index += 1
    src = temp1
    src_key_padding_mask = temp2
    return src, src_key_padding_mask


def beam_search(device, model, src, src_key_padding_mask, sos_token_id: int = 1, pad_token_id: int = 0,
                eos_token_id: int = 2, max_length: int = 20, num_beams: int = 6, vocab_size: int = 5993):
    batch_size = src.shape[0]
    src_mask = gen_nopeek_mask(src.shape[1]).to(device)
    src, src_key_padding_mask = beam_sizing(num_beams, src, src_key_padding_mask)
    memory_key_padding_mask = src_key_padding_mask
    beam_scores = torch.zeros((batch_size, num_beams)).to(device)
    beam_scores[:, 1:] = -1e9          # only beam 0 is "live" at the first step
    beam_scores = beam_scores.view(-1)
    done = [False for _ in range(batch_size)]
    generated_hyps = [
        BeamHypotheses(num_beams, max_length, length_penalty=0.7)
        for _ in range(batch_size)
    ]
    input_ids = torch.full((batch_size * num_beams, 1), sos_token_id, dtype=torch.long).to(device)
    cur_len = 1
    while cur_len < max_length:
        tar_mask = gen_nopeek_mask(input_ids.shape[1]).to(device)
        memory_key_padding_mask = src_key_padding_mask
        outputs, _ = model.forward(src, input_ids,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=None,
                                   src_mask=src_mask, tar_mask=tar_mask,
                                   memory_key_padding_mask=memory_key_padding_mask)
        next_token_logits = outputs[:, -1, :]
        scores = F.log_softmax(next_token_logits, dim=-1)
        next_scores = scores + beam_scores[:, None].expand_as(scores)
        next_scores = next_scores.view(batch_size, num_beams * vocab_size)
        # take 2*num_beams candidates so EOS hits do not starve the beam
        next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
        next_batch_beam = []
        for batch_idx in range(batch_size):
            if done[batch_idx]:
                next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)
                continue
            next_sent_beam = []
            for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                    zip(next_tokens[batch_idx], next_scores[batch_idx])):
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size
                effective_beam_id = batch_idx * num_beams + beam_id
                if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    generated_hyps[batch_idx].add(
                        input_ids[effective_beam_id].clone(),
                        beam_token_score.item(),
                    )
                else:
                    next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
                if len(next_sent_beam) == num_beams:
                    break
            done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                next_scores[batch_idx].max().item(), cur_len
            )
            next_batch_beam.extend(next_sent_beam)
        if all(done):
            break
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])
        # reorder the running state to follow the surviving beams
        input_ids = input_ids[beam_idx, :]
        src = src[beam_idx, ...]
        src_key_padding_mask = src_key_padding_mask[beam_idx, ...]
        input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
        cur_len = cur_len + 1
    # flush any still-running beams into the hypothesis pools
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            final_score = beam_scores[effective_beam_id].item()
            final_tokens = input_ids[effective_beam_id]
            generated_hyps[batch_idx].add(final_tokens, final_score)
    output_num_return_sequences_per_batch = num_beams
    output_batch_size = output_num_return_sequences_per_batch * batch_size
    sent_lengths = input_ids.new(output_batch_size)
    best = []
    best_score = []
    for i, hypotheses in enumerate(generated_hyps):
        sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
        for j in range(output_num_return_sequences_per_batch):
            effective_batch_idx = output_num_return_sequences_per_batch * i + j
            temp = sorted_hyps.pop()
            best_hyp = temp[1]
            best_s = temp[0]
            sent_lengths[effective_batch_idx] = len(best_hyp)
            best.append(best_hyp)
            best_score.append(best_s)
    if sent_lengths.min().item() != sent_lengths.max().item():
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
    else:
        decoded = torch.stack(best).type(torch.long)
    best_score = torch.tensor(best_score).type_as(next_scores)
    return decoded, best_score

Although this fixes the shortcoming of greedy search, beam search has flaws of its own. In practice it is very prone to repeating earlier tokens, especially in open-ended text generation; for machine translation it works reasonably well.
The example given in How to generate text: using different decoding methods for language generation with Transformers shows the output beginning to repeat after only a short sentence. To address this problem, random sampling was proposed.
3. Random sampling

As the name suggests, at each decoding step the next token is sampled directly from the predicted distribution. Compared with greedy search, the generated text gains some randomness and stops sounding so mechanical. The problems are equally obvious: the generated text can be incoherent, semantically self-contradictory, and prone to odd words.
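A minimal sketch of one sampling step, where the logits tensor is a stand-in for the model's next-token output:

import torch
import torch.nn.functional as F

logits = torch.randn(5000)                            # stand-in next-token logits [V]
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # sample instead of argmax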
4. Top-K sampling and top-p (nucleus) sampling

The paper The Curious Case of Neural Text Degeneration points out an interesting property of language: human language is consistently surprising, and is not the maximum-probability sequence that beam search extracts from a language model. In other words, beam search output is less surprising! Building on top-K sampling, the paper proposes nucleus sampling, also called top-p sampling. Let's look at both.
Top-K Sampling

This improves on random sampling: since sampling over the entire logits distribution causes incoherence, contradictions, and odd words, why not take the K highest-probability tokens, renormalize them into a new distribution, and draw from that multinomial distribution instead? The idea is simple and the torch implementation is not hard. In practice it brought a large improvement on GPT-2: the generated sentences are fluent, and repeated tokens drop sharply.
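A minimal sketch of one top-K step (the sizes and the logits tensor are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(5000)                     # stand-in next-token logits [V]
k = 50
topk_vals, topk_idx = torch.topk(logits, k)    # keep the K highest-probability tokens
probs = F.softmax(topk_vals, dim=-1)           # renormalize over the K tokens
next_token = topk_idx[torch.multinomial(probs, num_samples=1)]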
The figure shows the case K = 6: in the first decoding step the 6 tokens cover about two thirds of the total probability mass, and in the second step about 99%, and these tokens are all quite reasonable; combined with multinomial random sampling, this produces fluent text without repeated or odd words.
The difficulty of this method lies in how to choose K.
The logits distribution differs at every decoding step and changes dynamically, so a fixed K may select low-probability, implausible tokens. If K is too large, we are back to the problems of plain random sampling: incoherent text, semantic contradictions, odd words; if K is too small, the diversity of the output suffers and it becomes less surprising! Ideally K would adapt to the logits at every decoding step, which is exactly what nucleus sampling, top-p sampling, provides.
Top-p (nucleus) sampling

Unlike top-K sampling with its fixed K, top-p (nucleus) sampling accumulates probability over the logits from largest to smallest; once the cumulative probability exceeds a threshold, the tokens selected so far form a new distribution, and the next token is drawn from it by multinomial sampling!
In the example the cumulative probability threshold is p = 0.92: the first decoding step samples from 9 tokens, while the second samples from only 3. The method thus adapts dynamically to the logits, in effect using a different K at every step. One caveat: the threshold p is not easy to determine either, and is usually set from experience.
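A minimal single-distribution sketch of the nucleus filter (the batched version later in this post follows the same steps; the logits tensor is a stand-in):

import torch
import torch.nn.functional as F

logits = torch.randn(5000)                     # stand-in next-token logits [V]
p = 0.92
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > p                         # tokens past the probability mass p
remove[1:] = remove[:-1].clone()               # shift right: keep the first token over p
remove[0] = False
kept_idx = sorted_idx[~remove]
probs = F.softmax(sorted_logits[~remove], dim=-1)
next_token = kept_idx[torch.multinomial(probs, num_samples=1)]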
In terms of practical results, both top-K sampling and top-p (nucleus) sampling work quite well; in practice the two can also be combined, which keeps very low-probability tokens out of the candidate set while preserving the dynamic behaviour.
The top-k / top-p filtering code:
import torch
import torch.nn.functional as F


def top_k_top_p_filtering_batch(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    Args:
        logits: logits distribution, shape [batch, 1, vocabulary size]
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        for i in range(logits.shape[0]):
            # remove every token whose logit is below the k-th largest logit
            indices_to_remove = logits[i] < torch.topk(logits[i], top_k)[0][..., -1, None]
            logits[i][indices_to_remove] = filter_value
    if top_p > 0.0:
        for i in range(logits.shape[0]):
            sorted_logits, sorted_indices = torch.sort(logits[i], descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            # remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # shift right to keep the first token above the threshold as well
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[i][..., indices_to_remove] = filter_value   # mask along the vocab dimension
    return logits

The filter is then called directly in the decoding loop:
# assumes `model` (a GPT-style LM), `tokenizer`, `args`, and the prompt `input_ids` are prepared
curr_input_tensor = input_ids.to(device)
generated = []
for index in range(args.max_len):
    outputs = model(input_ids=curr_input_tensor)
    next_token_logits = outputs[0][:, -1:]                     # [batch, 1, vocab]
    if index >= 1:
        # repetition penalty: down-weight every token generated so far
        for i in range(gen_finall.shape[0]):
            gen_token_ids = gen_finall[i].clone()
            gen_token_ids = list(set(gen_token_ids.detach().cpu().tolist()))
            for id in gen_token_ids:
                next_token_logits[i:i + 1, :, id:id + 1] /= args.repetition_penalty
    next_token_logits = next_token_logits / args.temperature
    # never sample [UNK]
    token_unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
    next_token_logits[:, :, token_unk_id:token_unk_id + 1] = -float('Inf')
    filtered_logits = top_k_top_p_filtering_batch(next_token_logits, top_k=args.topk, top_p=args.topp)
    next_token = curr_input_tensor[:, -1:].clone()
    for i in range(next_token.shape[0]):
        next_token[i] = torch.multinomial(F.softmax(filtered_logits[i].squeeze(0), dim=-1), num_samples=1)
    generated.append(next_token)
    gen_finall = torch.cat(generated, dim=1)
    if (gen_finall == tokenizer.sep_token_id).sum(1).bool().all():
        break
    curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=1)

III. Decoding with the transformers library

So far we have covered the basic workflow of text generation and translation, the principles behind the decoding strategies, and their implementations. A more elegant option is simply to call the decoding functions for translation or generation tasks in Hugging Face's transformers, the de facto standard NLP library. generation_utils.py provides several decoding modes: greedy search, beam search, sampling (plain random sampling, top-K and top-p), beam_sample (beam search combined with top-K/top-p), and group_beam. For the remaining features, readers should consult the source code themselves.
Decoding is straightforward; as the code below shows, you load the model, feed in the data, decode, and read off the result.
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from data_reader.dataReader_zh2en import DataReader

if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("./pretrained_models/MarianMTModel_zh2en")
    model = AutoModelForSeq2SeqLM.from_pretrained("./pretrained_models/MarianMTModel_zh2en")
    dataset = DataReader(tokenizer, filepath='data/test_sample.csv')
    test_dataloader = DataLoader(dataset=dataset, batch_size=4)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    finanl_result = []
    for batch in tqdm(test_dataloader, desc='translation prediction'):
        for k, v in batch.items():
            batch[k] = v.to(device)
        batch = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
        translation = model.generate(**batch, top_k=5, num_return_sequences=1, num_beams=1)
        batch_result = tokenizer.batch_decode(translation, skip_special_tokens=True)
        finanl_result.extend(batch_result)
    print(len(finanl_result))
    for res in finanl_result:
        print(res.replace('[', '').replace(']', ''))

Below, a translation task serves as the example, using the transformer-based MarianMT model with the MarianMTModel_zh2en Chinese-to-English weights.
The complete code:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from data_reader.dataReader_zh2en import DataReader

if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("./pretrained_models/MarianMTModel_zh2en")
    model = AutoModelForSeq2SeqLM.from_pretrained("./pretrained_models/MarianMTModel_zh2en")
    dataset = DataReader(tokenizer, filepath='data/test_sample.csv')
    test_dataloader = DataLoader(dataset=dataset, batch_size=4)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    finanl_result = []
    for batch in tqdm(test_dataloader, desc='translation prediction'):
        for k, v in batch.items():
            batch[k] = v.to(device)
        batch = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}

        # greedy search
        greedy_translation = model.generate(**batch, num_return_sequences=1)
        greedy_batch_result = tokenizer.batch_decode(greedy_translation, skip_special_tokens=True)
        finanl_result.append(greedy_batch_result)

        # beam search
        beam_translation = model.generate(**batch, num_return_sequences=1, num_beams=5)
        beam_batch_result = tokenizer.batch_decode(beam_translation, skip_special_tokens=True)
        finanl_result.append(beam_batch_result)

        # random sampling
        sample_translation = model.generate(**batch, do_sample=True, num_return_sequences=1)
        sample_batch_result = tokenizer.batch_decode(sample_translation, skip_special_tokens=True)
        finanl_result.append(sample_batch_result)

        # top-k
        topk_translation = model.generate(**batch, top_k=5, num_return_sequences=1)
        topk_batch_result = tokenizer.batch_decode(topk_translation, skip_special_tokens=True)
        finanl_result.append(topk_batch_result)

        # top-p
        topp_translation = model.generate(**batch, top_p=0.92, num_return_sequences=1)
        topp_batch_result = tokenizer.batch_decode(topp_translation, skip_special_tokens=True)
        finanl_result.append(topp_batch_result)

        # top-k + top-p
        topktopp_translation = model.generate(**batch, top_k=5, top_p=0.92, num_return_sequences=1)
        topktopp_batch_result = tokenizer.batch_decode(topktopp_translation, skip_special_tokens=True)
        finanl_result.append(topktopp_batch_result)

        # top-k + top-p + beam search
        beamtopktopp_translation = model.generate(**batch, top_k=5, top_p=0.92,
                                                  num_return_sequences=1, num_beams=5)
        beamtopktopp_batch_result = tokenizer.batch_decode(beamtopktopp_translation, skip_special_tokens=True)
        finanl_result.append(beamtopktopp_batch_result)

    decodes_policys = ['greedy search', 'beam_search', 'sampling', 'top-k', 'top-p',
                       'top-k + top-p', 'top-k + top-p + beam_search']
    test_sample = ['【由富氏隱孢子蟲引起的皮膚真菌病】。',
                   '[十二指腸轉換手術中的減肥手術:體重變化和相關的營養缺乏]。',
                   '[宮腔鏡研究數字圖像的觀察者間診斷協議]。']
    print(len(finanl_result))
    for i in range(3):
        print(test_sample[i])
        for ele, de_ty in zip(finanl_result, decodes_policys):
            print(ele[i].replace('[', '').replace(']', ''))
        print('*' * 100)

The src texts to be translated:
【由富氏隱孢子蟲引起的皮膚真菌病】。
[十二指腸轉換手術中的減肥手術:體重變化和相關的營養缺乏]。
[宮腔鏡研究數字圖像的觀察者間診斷協議]。

The results of the different decoding strategies, listed per source text in the order greedy search, beam_search, sampling, top-k, top-p, top-k + top-p, top-k + top-p + beam_search:
【由富氏隱孢子蟲引起的皮膚真菌病】。
Skin fungi caused by Fung's Invisible Spores.
Skin fungus disease caused by Fung's Invisible Spores.
Skin fungi caused by Fung's spores.
Skin fungi caused by Fung's Invisible Spores.
Skin fungi caused by Fung's Invisible Spores.
Skin fungi caused by Fung's Invisible Spores.
Skin fungus disease caused by Fung's Invisible Spores.
****************************************************************************************************
[十二指腸轉換手術中的減肥手術:體重變化和相關的營養缺乏]。
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Liith finger intestinal conversion operations with dietary loss: weight changes and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
Twelve reference to fertility reduction in intestinal conversion operations: changes in body weight and associated nutritional deficiencies.
****************************************************************************************************
[宮腔鏡研究數字圖像的觀察者間診斷協議]。
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observatorial protocol for the study of digital images in the uterine cavity mirror.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.
Observer-to-observer protocol for the study of digital images in the court cavity mirrors.

Judging from this translation task, the differences between strategies are small, though some differences do appear.
References
How to generate text: using different decoding methods for language generation with Transformers
Nucleus Sampling與文本生成中的不同解碼策略比較
Seq2Seq解碼策略-概念
全面了解Beam Search
The Curious Case of Neural Text Degeneration