KBQA-Bert学习记录-构建BERT-CRF模型
生活随笔
收集整理的這篇文章主要介紹了
KBQA-Bert学习记录-构建BERT-CRF模型
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
目錄
1.__init__方法
2.forward方法
將bert和crf模型結合起來,簡單來說就是,設置好Bert模型,以及參數,得到的輸出結果給CRF模型即可。
1.__init__方法
這里面主要是bert的參數的定義及導入,還有bert模型的導入。
MODEL_NAME = 'bert-base-chinese-model.bin' CONFIG_NAME = 'bert-base-chinese-config.json' VOB_NAME = 'bert-base-chinese-vocab.txt'class BertCrf(nn.Module):def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:self.batch_first = batch_first# 模型配置文件、模型預訓練參數文件判斷if not os.path.exists(config_name):raise ValueError("未找到模型配置文件 '{}'".format(config_name))else:self.config_name = config_nameif model_name is not None:if not os.path.exists(model_name):raise ValueError("未找到模型預訓練參數文件 '{}'".format(model_name))else:self.model_name = model_nameelse:self.model_name = Noneif num_tags <= 0:raise ValueError(f'invalid number of tags: {num_tags}')super().__init__()# 配置bert的config文件self.bert_config = BertConfig.from_pretrained(self.config_name)self.bert_config.num_labels = num_tagsself.model_kwargs = {'config': self.bert_config}# 如果模型不存在if self.model_name is not None:self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)else:self.bertModel = BertForTokenClassification(self.bert_config)self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)2.forward方法
輸出的結果,經過處理后,輸入CRF函數,返回loss即可。
def forward(self, input_ids: torch.Tensor,tags: torch.Tensor = None,attention_mask: Optional[torch.ByteTensor] = None,token_type_ids=torch.Tensor,decode:bool = True,reduction: str = 'mean')->List:emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]# 去掉開頭的[CLS]以及結尾,結尾可能有兩種情況:1、<pad> 2、[SEP]new_emissions = emissions[:, 1:-1]new_mask = attention_mask[:, 2:].bool()# tags為None, 是預測過程,不能求lossif tags is None:loss = Nonepasselse:new_tags = tags[:, 1:-1]loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)if decode:tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)return [loss, tag_list]return [loss]總結
以上是生活随笔為你收集整理的KBQA-Bert学习记录-构建BERT-CRF模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python 拼音输入法_用Python
- 下一篇: 高数篇:高等数学全目录