【NLP】NLP实战篇之bert源码阅读(run_classifier)
本文主要會(huì)閱讀bert源碼
(https://github.com/google-research/bert )中run_classifier.py文件,已完成modeling.py、optimization.py、run_pretraining.py、tokenization.py、create_pretraining_data.py、extract_feature.py文件的源碼閱讀,后續(xù)會(huì)陸續(xù)閱讀bert的理解任務(wù)訓(xùn)練等源碼。本文介紹了run_classifier.py中的主要內(nèi)容,包括不同分類任務(wù)的數(shù)據(jù)讀取,用于分類的bert模型結(jié)構(gòu),和整體的訓(xùn)練流程。代碼中還涉及很多其他內(nèi)容,如運(yùn)行參數(shù),特征轉(zhuǎn)為tfrecord文件等等,由于在之前的閱讀中,出現(xiàn)過(guò)非常相似的內(nèi)容,所以這里不再重復(fù)。
run_classifier.py的全部代碼以及中文注釋可參考
https://github.com/wellinxu/nlp_store/blob/master/read_source/bert/run_classifier.py。
實(shí)戰(zhàn)系列篇章中主要會(huì)分享,解決實(shí)際問(wèn)題時(shí)的過(guò)程、遇到的問(wèn)題或者使用的工具等等。如問(wèn)題分解、bug排查、模型部署等等。相關(guān)代碼實(shí)現(xiàn)開源在:https://github.com/wellinxu/nlp_store ,。
分類任務(wù)
句對(duì)分類
單句分類
模型結(jié)構(gòu)
model_fn
main函數(shù)
其他
分類任務(wù)
源碼中,bert能處理的文本分類任務(wù)可簡(jiǎn)單分為兩種:句對(duì)分類任務(wù)(比如文本匹配/文本蘊(yùn)含等)和單句分類任務(wù)(比如情感分類/長(zhǎng)文本分類等)。代碼中涉及好幾個(gè)任務(wù)的數(shù)據(jù)讀取,不過(guò)因?yàn)榇笸‘?#xff0c;本文只分別講述一個(gè)示例,其他任務(wù)的相關(guān)代碼請(qǐng)參考原代碼。
不管是哪種分類任務(wù),都是將每個(gè)樣本轉(zhuǎn)化為一個(gè)InputExample類,如下面代碼所示,其中包含樣本的id,第一句文本,第二句文本(單句分類時(shí)為空)以及標(biāo)簽文本。
句對(duì)分類
關(guān)于句對(duì)分類任務(wù),我們主要分析MultiNLI數(shù)據(jù)集,這是一個(gè)文本蘊(yùn)含任務(wù),需要判斷前后兩句文本是對(duì)立、蘊(yùn)含還是中立關(guān)系。下面的這段代碼則是讀取該任務(wù)的代碼,邏輯很簡(jiǎn)單,主要就是將每個(gè)樣本的第一句、第二句和標(biāo)簽分別取出,然后構(gòu)建InputExample。
class?MnliProcessor(DataProcessor):"""處理MultiNLI數(shù)據(jù)集(GLUE版本)."""def?get_train_examples(self,?data_dir):"""See?base?class."""return?self._create_examples(self._read_tsv(os.path.join(data_dir,?"train.tsv")),?"train")def?get_dev_examples(self,?data_dir):return?self._create_examples(self._read_tsv(os.path.join(data_dir,?"dev_matched.tsv")),"dev_matched")def?get_test_examples(self,?data_dir):return?self._create_examples(self._read_tsv(os.path.join(data_dir,?"test_matched.tsv")),?"test")def?get_labels(self):return?["contradiction",?"entailment",?"neutral"]def?_create_examples(self,?lines,?set_type):"""生成樣本數(shù)據(jù)"""examples?=?[]for?(i,?line)?in?enumerate(lines):if?i?==?0:continueguid?=?"%s-%s"?%?(set_type,?tokenization.convert_to_unicode(line[0]))text_a?=?tokenization.convert_to_unicode(line[8])text_b?=?tokenization.convert_to_unicode(line[9])if?set_type?==?"test":label?=?"contradiction"else:label?=?tokenization.convert_to_unicode(line[-1])examples.append(InputExample(guid=guid,?text_a=text_a,?text_b=text_b,?label=label))return?examples單句分類
關(guān)于單句分類任務(wù),我們分析CoLA數(shù)據(jù)集,其是一個(gè)文本二分類問(wèn)題。如下面代碼所示,其邏輯更加簡(jiǎn)單,主要就是將單句文本與標(biāo)簽提取出來(lái),然后構(gòu)建InputExample。
class?ColaProcessor(DataProcessor):"""處理CoLA數(shù)據(jù)集(GLUE版本)."""def?get_train_examples(self,?data_dir):return?self._create_examples(self._read_tsv(os.path.join(data_dir,?"train.tsv")),?"train")def?get_dev_examples(self,?data_dir):return?self._create_examples(self._read_tsv(os.path.join(data_dir,?"dev.tsv")),?"dev")def?get_test_examples(self,?data_dir):return?self._create_examples(self._read_tsv(os.path.join(data_dir,?"test.tsv")),?"test")def?get_labels(self):return?["0",?"1"]def?_create_examples(self,?lines,?set_type):"""生成樣本數(shù)據(jù)"""examples?=?[]for?(i,?line)?in?enumerate(lines):#?測(cè)試集有header行,所以要跳過(guò)if?set_type?==?"test"?and?i?==?0:continueguid?=?"%s-%s"?%?(set_type,?i)if?set_type?==?"test":text_a?=?tokenization.convert_to_unicode(line[1])label?=?"0"else:text_a?=?tokenization.convert_to_unicode(line[3])label?=?tokenization.convert_to_unicode(line[1])examples.append(InputExample(guid=guid,?text_a=text_a,?text_b=None,?label=label))return?examples模型結(jié)構(gòu)
bert在做文本分類時(shí)的模型結(jié)構(gòu)比較簡(jiǎn)單,直接用pooled層的結(jié)果接一層全連接層+softmax。如下圖所示,句對(duì)分類任務(wù)與單句分類任務(wù)的模型結(jié)構(gòu)基本一致,只是開始的輸入略有差異。
模型的代碼構(gòu)建如下所示,建立bert模型之后,獲取pooled層輸出(bertd 模型結(jié)構(gòu)和pooled層輸出可參考modeling.py),然后接上全連接層計(jì)算softmax,最后計(jì)算交叉熵loss,用來(lái)訓(xùn)練。
model_fn
上面已經(jīng)得到了模型結(jié)構(gòu),但在任務(wù)運(yùn)行過(guò)程中,一般會(huì)分為訓(xùn)練階段、驗(yàn)證階段和預(yù)測(cè)階段,每個(gè)階段需要計(jì)算的算子和返回的結(jié)果略有不同,所以就有了以下代碼。其主要邏輯是,將樣本數(shù)據(jù)輸入給bert模型之后,根據(jù)階段的不同,分別獲取訓(xùn)練階段的優(yōu)化操作、驗(yàn)證階段的評(píng)估操作和預(yù)測(cè)階段的直接結(jié)果。
def?model_fn_builder(bert_config,?num_labels,?init_checkpoint,?learning_rate,num_train_steps,?num_warmup_steps,?use_tpu,use_one_hot_embeddings):"""返回給TPUEstimator使用的模型函數(shù)-model_fn可以參考run_pretraining.py中的model_fn_builder方法"""def?model_fn(features,?labels,?mode,?params):??#?pylint:?disable=unused-argument"""待返回的模型函數(shù),model_fn"""tf.logging.info("***?Features?***")for?name?in?sorted(features.keys()):tf.logging.info("??name?=?%s,?shape?=?%s"?%?(name,?features[name].shape))#?樣本輸入數(shù)據(jù)input_ids?=?features["input_ids"]input_mask?=?features["input_mask"]segment_ids?=?features["segment_ids"]label_ids?=?features["label_ids"]is_real_example?=?Noneif?"is_real_example"?in?features:is_real_example?=?tf.cast(features["is_real_example"],?dtype=tf.float32)else:????#?tpu訓(xùn)練需要固定尺寸,所以在某些step中樣本不夠的時(shí)候需要構(gòu)建假樣本is_real_example?=?tf.ones(tf.shape(label_ids),?dtype=tf.float32)is_training?=?(mode?==?tf.estimator.ModeKeys.TRAIN)#?將樣本輸入模型得到結(jié)果(total_loss,?per_example_loss,?logits,?probabilities)?=?create_model(bert_config,?is_training,?input_ids,?input_mask,?segment_ids,?label_ids,num_labels,?use_one_hot_embeddings)tvars?=?tf.trainable_variables()initialized_variable_names?=?{}scaffold_fn?=?Noneif?init_checkpoint:???#?是否只是初始化模型參數(shù)(assignment_map,?initialized_variable_names)?=?modeling.get_assignment_map_from_checkpoint(tvars,?init_checkpoint)if?use_tpu:def?tpu_scaffold():tf.train.init_from_checkpoint(init_checkpoint,?assignment_map)return?tf.train.Scaffold()scaffold_fn?=?tpu_scaffoldelse:tf.train.init_from_checkpoint(init_checkpoint,?assignment_map)tf.logging.info("****?Trainable?Variables?****")for?var?in?tvars:init_string?=?""if?var.name?in?initialized_variable_names:init_string?=?",?*INIT_FROM_CKPT*"tf.logging.info("??name?=?%s,?shape?=?%s%s",?var.name,?var.shape,init_string)output_spec?=?Noneif?mode?==?tf.estimator.ModeKeys.TRAIN:????#?訓(xùn)練模式train_op?=?optimization.create_optimizer(total_loss,?learning_rate,?num_train_steps,?num_warmup_steps,?use_tpu)output_spec?=?tf.contrib.tpu.TPUEstimatorSpec(mode=mode,loss=total_loss,train_op=train_op,scaffold_fn=scaffold_fn)elif?mode?==?tf.estimator.ModeKeys.EVAL:???#?評(píng)估模式#?評(píng)價(jià)函數(shù),會(huì)計(jì)算評(píng)估數(shù)據(jù)集結(jié)果的準(zhǔn)確性和lossdef?metric_fn(per_example_loss,?label_ids,?logits,?is_real_example):predictions?=?tf.argmax(logits,?axis=-1,?output_type=tf.int32)accuracy?=?tf.metrics.accuracy(labels=label_ids,?predictions=predictions,?weights=is_real_example)loss?=?tf.metrics.mean(values=per_example_loss,?weights=is_real_example)return?{"eval_accuracy":?accuracy,"eval_loss":?loss,}#?評(píng)估模式會(huì)返回評(píng)估結(jié)果eval_metrics?=?(metric_fn,[per_example_loss,?label_ids,?logits,?is_real_example])output_spec?=?tf.contrib.tpu.TPUEstimatorSpec(mode=mode,loss=total_loss,eval_metrics=eval_metrics,scaffold_fn=scaffold_fn)else:???#?預(yù)測(cè)模式,只要返回預(yù)測(cè)值output_spec?=?tf.contrib.tpu.TPUEstimatorSpec(mode=mode,predictions={"probabilities":?probabilities},scaffold_fn=scaffold_fn)return?output_specreturn?model_fnmain函數(shù)
在知道輸入讀取與模型結(jié)構(gòu)之后,我們來(lái)看下分類任務(wù)的主體結(jié)構(gòu)main函數(shù)。其主要邏輯如下:
檢查并測(cè)試bert相關(guān)參數(shù)
根據(jù)任務(wù)名稱獲取數(shù)據(jù)處理類
設(shè)置訓(xùn)練參數(shù),構(gòu)建bert模型與estimator
如果執(zhí)行訓(xùn)練階段:
將訓(xùn)練樣本保存為tfrecord格式
將訓(xùn)練樣本轉(zhuǎn)換為訓(xùn)練輸入函數(shù)
訓(xùn)練模型
如果執(zhí)行驗(yàn)證階段:
將驗(yàn)證樣本保存為tfrecord格式
將驗(yàn)證樣本轉(zhuǎn)換為驗(yàn)證輸入函數(shù)
驗(yàn)證模型
將評(píng)估結(jié)果寫入文件
如果執(zhí)行預(yù)測(cè)階段:
將預(yù)測(cè)樣本保存為tfrecord格式
將預(yù)測(cè)樣本轉(zhuǎn)化為預(yù)測(cè)輸入函數(shù)
模型預(yù)測(cè)
將預(yù)測(cè)結(jié)果寫入文件
其中將數(shù)據(jù)轉(zhuǎn)化為tfrecord格式,是file_based_convert_examples_to_features函數(shù)實(shí)現(xiàn)的,可參考create_pretraining_data.py中的write_instance_to_example_files方法,不再贅述;而轉(zhuǎn)為輸入函數(shù),則是file_based_input_fn_builder函數(shù)實(shí)現(xiàn)的,可以參考run_pretrainin.py中的input_fn_builder方法,也不贅述。整體的main代碼如下:
def?main(_):tf.logging.set_verbosity(tf.logging.INFO)????#?設(shè)置日志等級(jí)#?數(shù)據(jù)處理類的映射processors?=?{"cola":?ColaProcessor,"mnli":?MnliProcessor,"mrpc":?MrpcProcessor,"xnli":?XnliProcessor,}#?校驗(yàn)?zāi)P蛥?shù)tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,FLAGS.init_checkpoint)if?not?FLAGS.do_train?and?not?FLAGS.do_eval?and?not?FLAGS.do_predict:raise?ValueError("At?least?one?of?`do_train`,?`do_eval`?or?`do_predict'?must?be?True.")bert_config?=?modeling.BertConfig.from_json_file(FLAGS.bert_config_file)????#?獲取bert配置if?FLAGS.max_seq_length?>?bert_config.max_position_embeddings:raise?ValueError("Cannot?use?sequence?length?%d?because?the?BERT?model?""was?only?trained?up?to?sequence?length?%d"?%(FLAGS.max_seq_length,?bert_config.max_position_embeddings))tf.gfile.MakeDirs(FLAGS.output_dir)task_name?=?FLAGS.task_name.lower()????#?獲取任務(wù)名稱if?task_name?not?in?processors:raise?ValueError("Task?not?found:?%s"?%?(task_name))processor?=?processors[task_name]()????#?獲取數(shù)據(jù)處理類label_list?=?processor.get_labels()????#?獲取標(biāo)簽集合#?初始化token切分器tokenizer?=?tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,?do_lower_case=FLAGS.do_lower_case)#?tpu相關(guān)tpu_cluster_resolver?=?Noneif?FLAGS.use_tpu?and?FLAGS.tpu_name:tpu_cluster_resolver?=?tf.contrib.cluster_resolver.TPUClusterResolver(FLAGS.tpu_name,?zone=FLAGS.tpu_zone,?project=FLAGS.gcp_project)#?tpu相關(guān)is_per_host?=?tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2run_config?=?tf.contrib.tpu.RunConfig(cluster=tpu_cluster_resolver,master=FLAGS.master,model_dir=FLAGS.output_dir,save_checkpoints_steps=FLAGS.save_checkpoints_steps,tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=FLAGS.iterations_per_loop,num_shards=FLAGS.num_tpu_cores,per_host_input_for_training=is_per_host))train_examples?=?None????#?訓(xùn)練樣本num_train_steps?=?None????#?訓(xùn)練步數(shù)num_warmup_steps?=?None????#?warmup步數(shù)if?FLAGS.do_train:train_examples?=?processor.get_train_examples(FLAGS.data_dir)num_train_steps?=?int(len(train_examples)?/?FLAGS.train_batch_size?*?FLAGS.num_train_epochs)num_warmup_steps?=?int(num_train_steps?*?FLAGS.warmup_proportion)#?模型函數(shù),輸入到輸出中間的結(jié)構(gòu)定義model_fn?=?model_fn_builder(bert_config=bert_config,num_labels=len(label_list),init_checkpoint=FLAGS.init_checkpoint,learning_rate=FLAGS.learning_rate,num_train_steps=num_train_steps,num_warmup_steps=num_warmup_steps,use_tpu=FLAGS.use_tpu,use_one_hot_embeddings=FLAGS.use_tpu)#?如果tpu不可用,則會(huì)退化成cpu或者gpu版本estimator?=?tf.contrib.tpu.TPUEstimator(use_tpu=FLAGS.use_tpu,model_fn=model_fn,config=run_config,train_batch_size=FLAGS.train_batch_size,eval_batch_size=FLAGS.eval_batch_size,predict_batch_size=FLAGS.predict_batch_size)#?進(jìn)行訓(xùn)練if?FLAGS.do_train:train_file?=?os.path.join(FLAGS.output_dir,?"train.tf_record")#?將輸入數(shù)據(jù)轉(zhuǎn)換為tfrecord格式,并保存file_based_convert_examples_to_features(train_examples,?label_list,?FLAGS.max_seq_length,?tokenizer,?train_file)tf.logging.info("*****?Running?training?*****")tf.logging.info("??Num?examples?=?%d",?len(train_examples))tf.logging.info("??Batch?size?=?%d",?FLAGS.train_batch_size)tf.logging.info("??Num?steps?=?%d",?num_train_steps)#?訓(xùn)練的輸入函數(shù),產(chǎn)生訓(xùn)練輸入樣本train_input_fn?=?file_based_input_fn_builder(input_file=train_file,seq_length=FLAGS.max_seq_length,is_training=True,drop_remainder=True)#?根據(jù)輸入函數(shù)與模型函數(shù)進(jìn)行訓(xùn)練模型estimator.train(input_fn=train_input_fn,?max_steps=num_train_steps)#?進(jìn)行評(píng)估if?FLAGS.do_eval:eval_examples?=?processor.get_dev_examples(FLAGS.data_dir)num_actual_eval_examples?=?len(eval_examples)if?FLAGS.use_tpu:#?TPU需要固定大小的batch,添加加樣本補(bǔ)足while?len(eval_examples)?%?FLAGS.eval_batch_size?!=?0:eval_examples.append(PaddingInputExample())eval_file?=?os.path.join(FLAGS.output_dir,?"eval.tf_record")#?將輸入數(shù)據(jù)轉(zhuǎn)換為tfrecord格式,并保存file_based_convert_examples_to_features(eval_examples,?label_list,?FLAGS.max_seq_length,?tokenizer,?eval_file)tf.logging.info("*****?Running?evaluation?*****")tf.logging.info("??Num?examples?=?%d?(%d?actual,?%d?padding)",len(eval_examples),?num_actual_eval_examples,len(eval_examples)?-?num_actual_eval_examples)tf.logging.info("??Batch?size?=?%d",?FLAGS.eval_batch_size)eval_steps?=?None#?使用TPU時(shí),需要知道具體運(yùn)行步數(shù)if?FLAGS.use_tpu:assert?len(eval_examples)?%?FLAGS.eval_batch_size?==?0eval_steps?=?int(len(eval_examples)?//?FLAGS.eval_batch_size)eval_drop_remainder?=?True?if?FLAGS.use_tpu?else?False#?評(píng)估的輸入函數(shù),產(chǎn)生評(píng)估輸入樣本eval_input_fn?=?file_based_input_fn_builder(input_file=eval_file,seq_length=FLAGS.max_seq_length,is_training=False,drop_remainder=eval_drop_remainder)#?根據(jù)輸入函數(shù)與模型函數(shù)進(jìn)行模型評(píng)估result?=?estimator.evaluate(input_fn=eval_input_fn,?steps=eval_steps)#?評(píng)估結(jié)果寫入文件output_eval_file?=?os.path.join(FLAGS.output_dir,?"eval_results.txt")with?tf.gfile.GFile(output_eval_file,?"w")?as?writer:tf.logging.info("*****?Eval?results?*****")for?key?in?sorted(result.keys()):tf.logging.info("??%s?=?%s",?key,?str(result[key]))writer.write("%s?=?%s\n"?%?(key,?str(result[key])))#?進(jìn)行預(yù)測(cè)if?FLAGS.do_predict:predict_examples?=?processor.get_test_examples(FLAGS.data_dir)num_actual_predict_examples?=?len(predict_examples)if?FLAGS.use_tpu:#??TPU需要固定大小的batch,添加加樣本補(bǔ)足while?len(predict_examples)?%?FLAGS.predict_batch_size?!=?0:predict_examples.append(PaddingInputExample())predict_file?=?os.path.join(FLAGS.output_dir,?"predict.tf_record")#?將輸入數(shù)據(jù)轉(zhuǎn)換為tfrecord格式,并保存file_based_convert_examples_to_features(predict_examples,?label_list,FLAGS.max_seq_length,?tokenizer,predict_file)tf.logging.info("*****?Running?prediction*****")tf.logging.info("??Num?examples?=?%d?(%d?actual,?%d?padding)",len(predict_examples),?num_actual_predict_examples,len(predict_examples)?-?num_actual_predict_examples)tf.logging.info("??Batch?size?=?%d",?FLAGS.predict_batch_size)predict_drop_remainder?=?True?if?FLAGS.use_tpu?else?False#?預(yù)測(cè)的輸入函數(shù),產(chǎn)生預(yù)測(cè)輸入樣本predict_input_fn?=?file_based_input_fn_builder(input_file=predict_file,seq_length=FLAGS.max_seq_length,is_training=False,drop_remainder=predict_drop_remainder)#?根據(jù)輸入函數(shù)與模型函數(shù)使用模型預(yù)測(cè)result?=?estimator.predict(input_fn=predict_input_fn)#?預(yù)測(cè)結(jié)果寫入文件output_predict_file?=?os.path.join(FLAGS.output_dir,?"test_results.tsv")with?tf.gfile.GFile(output_predict_file,?"w")?as?writer:num_written_lines?=?0tf.logging.info("*****?Predict?results?*****")for?(i,?prediction)?in?enumerate(result):probabilities?=?prediction["probabilities"]if?i?>=?num_actual_predict_examples:breakoutput_line?=?"\t".join(str(class_probability)for?class_probability?in?probabilities)?+?"\n"writer.write(output_line)num_written_lines?+=?1assert?num_written_lines?==?num_actual_predict_examples其他
文本分類代碼與預(yù)訓(xùn)練和create_data代碼有很多相似代碼,這邊都不再贅述,比如convert_single_example、_truncate_seq_pair、file_based_convert_examples_to_features、file_based_input_fn_builder、input_fn_builder等函數(shù)都可以找到非常相似的代碼。而模型的運(yùn)行參數(shù)與之前也是大同小異,具體參數(shù)以及整體代碼及中文注釋,都可以參考https://github.com/wellinxu/nlp_store/blob/master/read_source/bert/run_classifier.py。
往期精彩回顧適合初學(xué)者入門人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊(cè)深度學(xué)習(xí)筆記專輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專輯溫州大學(xué)《機(jī)器學(xué)習(xí)課程》視頻 本站qq群851320808,加入微信群請(qǐng)掃碼:總結(jié)
以上是生活随笔為你收集整理的【NLP】NLP实战篇之bert源码阅读(run_classifier)的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 番茄花园win11 32位官方原版镜像文
- 下一篇: 【机器学习】基于LDA主题模型的人脸识别