Paddle: Waybill Information Extraction with a Pre-trained Model
Contents
- 1. Imports
- 2. Data processing
- 3. Helper functions
- 3.1 Evaluation function
- 3.2 Prediction function
- 3.3 Decoding predictions
- 4. Training
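When filling out a waybill, the user can paste all of the information into the client in one go. The client automatically recognizes the province and city, the person's name, the phone number and other fields, fills each into the right slot, and prints the waybill ready to be attached. No manual entry is needed, which speeds up the work.
Learned from: https://aistudio.baidu.com/aistudio/projectdetail/1329361
Using a pre-trained model plus fine-tuning, we train a waybill information extraction model.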
1. Imports

# Waybill information extraction
from functools import partial  # wrap a function with preset default arguments
import paddle
from paddlenlp.datasets import MapDataset  # custom dataset
from paddlenlp.data import Stack, Tuple, Pad  # batching utilities
from paddlenlp.transformers import ErnieTokenizer, ErnieForTokenClassification
from paddlenlp.metrics import ChunkEvaluator  # metric computation
from paddle.utils.download import get_path_from_url

2. Data processing
URL = "https://paddlenlp.bj.bcebos.com/paddlenlp/datasets/waybill.tar.gz"
get_path_from_url(URL, "./")
epochs = 10
batch_size = 16

def load_dict(dict_path):
    # Read the dictionary: one entry per line, mapped to its line index
    vocab = {}
    with open(dict_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            vocab[line.strip('\n')] = i
    return vocab

# Show the data format
with open("./data/test.txt", 'r', encoding='utf-8') as f:
    i = 0
    for line in f:
        print(line)
        i += 1
        if i > 5:
            break

# text_a label
#
# 黑龍江省雙鴨山市尖山區八馬路與東平行路交叉口北40米韋業濤18600009172
# A1-BA1-IA1-IA1-IA2-BA2-IA2-IA2-IA3-BA3-IA3-IA4-BA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IA4-IP-BP-IP-IT-BT-IT-IT-IT-IT-IT-IT-IT-IT-IT-I
# A1 = province; -B = beginning, -I = inside; P = person name; T = phone number

- Data conversion function: turn the text into numeric ids
- Load the datasets
- Batch the data
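The three steps listed above can be sketched without any framework. This is only an illustration under stated assumptions, not the article's code: it assumes each data line holds the text and the tags separated by a tab, with individual characters and tags separated by `\x02`; it uses a toy character vocabulary in place of the tokenizer; and the names `parse_line`, `convert_example`, `pad_batch` and `IGNORE_LABEL` are hypothetical.

```python
# Framework-free sketch of: convert text to ids, load samples, batch them.
# File-format details and all helper names here are assumptions.

SEP = "\x02"          # assumed separator between characters / tags
IGNORE_LABEL = -1     # assumed padding value for labels


def parse_line(line):
    """Split one 'text<TAB>labels' line into a character list and a tag list."""
    text, labels = line.rstrip("\n").split("\t")
    return text.split(SEP), labels.split(SEP)


def convert_example(chars, tags, char_vocab, label_vocab):
    """Turn characters and tags into id lists, adding [CLS]/[SEP] slots."""
    unk = char_vocab["[UNK]"]
    input_ids = ([char_vocab["[CLS]"]]
                 + [char_vocab.get(c, unk) for c in chars]
                 + [char_vocab["[SEP]"]])
    # The two special tokens carry the 'O' (outside) label.
    label_ids = [label_vocab[t] for t in ["O"] + tags + ["O"]]
    return input_ids, len(input_ids), label_ids


def pad_batch(samples, pad_id, ignore_label=IGNORE_LABEL):
    """Pad every sample in a batch to the length of the longest one."""
    max_len = max(n for _, n, _ in samples)
    ids = [seq + [pad_id] * (max_len - n) for seq, n, _ in samples]
    lens = [n for _, n, _ in samples]
    labels = [lab + [ignore_label] * (max_len - n) for _, n, lab in samples]
    return ids, lens, labels


# Toy vocabularies standing in for the tokenizer vocabulary and the tag dictionary.
char_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "河": 4, "北": 5, "省": 6}
label_vocab = {"A1-B": 0, "A1-I": 1, "P-B": 2, "P-I": 3, "O": 4}

lines = [
    SEP.join("河北省") + "\t" + SEP.join(["A1-B", "A1-I", "A1-I"]) + "\n",
    SEP.join("河北") + "\t" + SEP.join(["A1-B", "A1-I"]) + "\n",
]
samples = [convert_example(*parse_line(l), char_vocab, label_vocab) for l in lines]
batch_ids, batch_lens, batch_labels = pad_batch(samples, pad_id=char_vocab["[PAD]"])
print(batch_ids)
print(batch_lens)
print(batch_labels)
```

In the actual pipeline these roles would presumably be filled by the classes imported in section 1: `ErnieTokenizer` for the id conversion, `MapDataset` for holding the samples, and `Pad`, `Stack` and `Tuple` for batching. Padding the labels with a dedicated ignore value lets the loss skip the padded positions.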
3. Helper functions
3.1 Evaluation function
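The evaluation function in this section reports precision, recall and F1 through `ChunkEvaluator`, which counts whole labeled chunks (a full province, a full name) rather than single characters. As a rough illustration of what those three numbers mean, and not a reproduction of `ChunkEvaluator` itself, here is a minimal pure-Python sketch. The helper names `extract_chunks` and `chunk_prf` are hypothetical, and an `-I` tag with no matching `-B` before it is simply treated as outside any chunk.

```python
def extract_chunks(tags):
    """Return a set of (start, end, type) spans from suffix-style tags
    such as 'A1-B', 'A1-I' and 'O'."""
    chunks, start, kind = set(), None, None
    for i, tag in enumerate(tags + ["O"]):       # sentinel flushes the last chunk
        continues = tag.endswith("-I") and kind == tag[:-2]
        if not continues:
            if kind is not None:
                chunks.add((start, i, kind))     # close the open chunk
            if tag.endswith("-B"):
                start, kind = i, tag[:-2]        # open a new chunk
            else:
                start, kind = None, None
    return chunks


def chunk_prf(pred_seqs, gold_seqs):
    """Chunk-level precision, recall and F1 over a list of tag sequences."""
    n_infer = n_label = n_correct = 0
    for pred, gold in zip(pred_seqs, gold_seqs):
        p, g = extract_chunks(pred), extract_chunks(gold)
        n_infer += len(p)          # chunks the model predicted
        n_label += len(g)          # chunks in the reference
        n_correct += len(p & g)    # exact span-and-type matches
    precision = n_correct / n_infer if n_infer else 0.0
    recall = n_correct / n_label if n_label else 0.0
    f1 = 2 * precision * recall / (precision + recall) if n_correct else 0.0
    return precision, recall, f1


gold = [["A1-B", "A1-I", "P-B", "P-I", "T-B"]]
pred = [["A1-B", "A1-I", "P-B", "O", "T-B"]]
precision, recall, f1 = chunk_prf(pred, gold)
print(precision, recall, f1)
```

In this toy case the predicted name chunk is one character too short, so it counts as wrong even though one of its two characters is tagged correctly. That strictness is why chunk-level scores are a better fit for field extraction than per-character accuracy.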
@paddle.no_grad()
def evaluate(model, metric, data_loader):
    model.eval()
    metric.reset()
    for input_ids, seg_ids, lens, labels in data_loader:
        logits = model(input_ids, seg_ids)
        preds = paddle.argmax(logits, axis=-1)
        n_infer, n_label, n_correct = metric.compute(None, lens, preds, labels)
        metric.update(n_infer.numpy(), n_label.numpy(), n_correct.numpy())
    precision, recall, f1_score = metric.accumulate()
    print("eval precision: %f - recall: %f - f1: %f" % (precision, recall, f1_score))
    model.train()

3.2 Prediction function
@paddle.no_grad()
def predict(model, data_loader, ds, label_vocab):
    # Switch to evaluation mode so that dropout is disabled during prediction
    model.eval()
    pred_list = []
    len_list = []
    for input_ids, seg_ids, lens, labels in data_loader:
        logits = model(input_ids, seg_ids)
        pred = paddle.argmax(logits, axis=-1)
        pred_list.append(pred.numpy())
        len_list.append(lens.numpy())
    preds = parse_decodes(ds, pred_list, len_list, label_vocab)
    return preds

3.3 Decoding predictions
def parse_decodes(ds, decodes, lens, label_vocab):
    # Flatten the per-batch predictions and lengths into per-sample lists
    decodes = [x for batch in decodes for x in batch]
    lens = [x for batch in lens for x in batch]
    # Reverse mapping: label id -> label string
    id_label = dict(zip(label_vocab.values(), label_vocab.keys()))
    outputs = []
    for idx, end in enumerate(lens):
        sent = ds.data[idx][0][:end]
        tags = [id_label[x] for x in decodes[idx][1:end]]
        sent_out = []
        tags_out = []
        words = ""
        for s, t in zip(sent, tags):
            if t.endswith('-B') or t == 'O':
                if len(words):
                    sent_out.append(words)
                tags_out.append(t.split('-')[0])
                words = s
            else:
                words += s
        if len(sent_out) < len(tags_out):
            sent_out.append(words)
        outputs.append(''.join([str((s, t)) for s, t in zip(sent_out, tags_out)]))
    return outputs

4. Training
# Load the pre-trained model
# (label_vocab, ignore_label and the train/dev/test loaders come from the data-processing step)
model = ErnieForTokenClassification.from_pretrained("ernie-1.0", num_classes=len(label_vocab))
# Metric
metric = ChunkEvaluator(label_list=label_vocab.keys(), suffix=True)
# Loss function
loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=ignore_label)
# Optimizer
optimizer = paddle.optimizer.AdamW(learning_rate=2e-5, parameters=model.parameters())

# Training
step = 0
for epoch in range(epochs):
    for idx, (input_ids, token_type_ids, length, labels) in enumerate(train_loader):
        logits = model(input_ids, token_type_ids)
        loss = paddle.mean(loss_fn(logits, labels))
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        step += 1
        print("epoch:%d - step:%d - loss: %f" % (epoch, step, loss))
    # Evaluate once per epoch
    evaluate(model, metric, dev_loader)
    # Save the model parameters
    paddle.save(model.state_dict(), './ernie_result/model_%d.pdparams' % step)

# Training finished: load the parameters saved after the last epoch
state_dict = paddle.load('./ernie_result/model_%d.pdparams' % step)
model.load_dict(state_dict)

# Prediction
preds = predict(model, test_loader, test_ds, label_vocab)
file_path = "ernie_results.txt"
with open(file_path, "w", encoding="utf8") as fout:
    fout.write("\n".join(preds))
# Print the prediction results
print("The results have been saved in the file: %s, some examples are shown below: " % file_path)
print("\n".join(preds[:10]))

Training process:
epoch:0 - step:1 - loss: 2.788503
epoch:0 - step:2 - loss: 2.520449
epoch:0 - step:3 - loss: 2.365216
epoch:0 - step:4 - loss: 2.255839
epoch:0 - step:5 - loss: 2.108390
epoch:0 - step:6 - loss: 2.006438
...
epoch:0 - step:100 - loss: 0.045199
eval precision: 0.969141 - recall: 0.977292 - f1: 0.973199
epoch:1 - step:101 - loss: 0.026065
...
epoch:1 - step:200 - loss: 0.012335
eval precision: 0.984925 - recall: 0.989066 - f1: 0.986991
epoch:2 - step:201 - loss: 0.014337
...
epoch:2 - step:300 - loss: 0.004556
eval precision: 0.987427 - recall: 0.990749 - f1: 0.989085
epoch:3 - step:301 - loss: 0.003423
...
epoch:3 - step:400 - loss: 0.002968
eval precision: 0.987427 - recall: 0.990749 - f1: 0.989085
epoch:4 - step:401 - loss: 0.001868
...
epoch:4 - step:500 - loss: 0.016371
eval precision: 0.989933 - recall: 0.992431 - f1: 0.991180
epoch:5 - step:501 - loss: 0.006276
...
epoch:5 - step:530 - loss: 0.001634
...

Some prediction results:
The results have been saved in the file: ernie_results.txt, some examples are shown below:
('黑龍江省', 'A1')('雙鴨山市', 'A2')('尖山區', 'A3')('八馬路與東平行路交叉口北40米', 'A4')('韋業濤', 'P')('18600009172', 'T')
('廣西壯族自治區', 'A1')('桂林市', 'A2')('雁山區', 'A3')('雁山鎮西龍村老年活動中心', 'A4')('17610348888', 'T')('羊卓衛', 'P')
('15652864561', 'T')('河南省', 'A1')('開封市', 'A2')('順河回族區', 'A3')('順河區公園路32號', 'A4')('趙本山', 'P')
('河北省', 'A1')('唐山市', 'A2')('玉田縣', 'A3')('無終大街159號', 'A4')('18614253058', 'T')('尚漢生', 'P')
('臺灣', 'A1')('臺中市', 'A2')('北區', 'A3')('北區錦新街18號', 'A4')('18511226708', 'T')('薊麗', 'P')
('廖梓琪', 'P')('18514743222', 'T')('湖北省', 'A1')('宜昌市', 'A2')('長陽土家族自治縣', 'A3')('賀家坪鎮賀家坪村一組臨河1號', 'A4')
('江蘇省', 'A1')('南通市', 'A2')('海門市', 'A3')('孝威村孝威路88號', 'A4')('18611840623', 'T')('計星儀', 'P')
('17601674746', 'T')('趙春麗', 'P')('內蒙古自治區', 'A1')('烏蘭察布市', 'A2')('涼城縣', 'A3')('新建街', 'A4')
('云南省', 'A1')('臨滄市', 'A2')('耿馬傣族佤族自治縣', 'A3')('鑫源路法院對面', 'A4')('許貞愛', 'P')('18510566685', 'T')
('四川省', 'A1')('成都市', 'A2')('雙流區', 'A3')('東升鎮北倉路196號', 'A4')('耿丕嶺', 'P')('18513466161', 'T')
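Each result line is a sequence of (text, tag) pairs, built by merging consecutive characters that belong to the same field. The merging step on its own can be sketched as follows. This is a simplified standalone variant rather than the `parse_decodes` function from section 3.3: it takes characters and tag strings directly, leaving out the dataset lookup and the id-to-label mapping, and the name `merge_tags` is hypothetical.

```python
def merge_tags(chars, tags):
    """Group characters into (text, type) pairs using '-B' / '-I' suffix tags."""
    pairs = []
    for ch, tag in zip(chars, tags):
        kind = tag.split("-")[0]
        if tag.endswith("-I") and pairs and pairs[-1][1] == kind:
            # The same field continues: extend the last piece.
            pairs[-1] = (pairs[-1][0] + ch, kind)
        else:
            # '-B', 'O', or an '-I' with no matching predecessor starts a new piece.
            pairs.append((ch, kind))
    return pairs


chars = list("河北省唐山市尚漢生")
tags = ["A1-B", "A1-I", "A1-I", "A2-B", "A2-I", "A2-I", "P-B", "P-I", "P-I"]
result = merge_tags(chars, tags)
print("".join(str(p) for p in result))
```

Because the grouping relies only on the tag sequence, the fields come out in whatever order they appear in the pasted text, which matches the results above where the phone number sometimes precedes the address.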