AVOD-代码阅读理解系列(一)
生活随笔
收集整理的這篇文章主要介紹了
AVOD-代码阅读理解系列(一)
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
AVOD-代碼理解
代碼源碼鏈接:https://github.com/kujason/avod
論文鏈接:https://arxiv.org/abs/1712.02294
本系列博客用于記錄學習AVOD代碼,其代碼注釋是本人自己寫的,本人是個python新手,很多地方的不對之處歡迎各位指正.整個博客系列全是從pycharm上直接復制下來的,可能不大方便看.
1. run_training.py.整個程序的訓練開始部分.主要就是讀取訓練的config文件,對訓練進行相應設置.這個config文件在avod/configs文件夾下,在裡面可以看到有好幾個config設置,在實際訓練時我們會選擇其中的一個
#coding=utf-8 """Detection model trainer.This runs the DetectionModel trainer. """import argparse import osimport tensorflow as tfimport avod import avod.builders.config_builder_util as config_builder from avod.builders.dataset_builder import DatasetBuilder from avod.core.models.avod_model import AvodModel from avod.core.models.rpn_model import RpnModel from avod.core import trainertf.logging.set_verbosity(tf.logging.ERROR)def train(model_config, train_config, dataset_config):#一堆操作!!!!讀取config文件里面的詳細內容dataset = DatasetBuilder.build_kitti_dataset(dataset_config,use_defaults=False)train_val_test = 'train'#avodmodel_name = model_config.model_namewith tf.Graph().as_default():if model_name == 'rpn_model':model = RpnModel(model_config,train_val_test=train_val_test,dataset=dataset)elif model_name == 'avod_model':#avod_model,train,dataset.也就是avod_model的相關設置model = AvodModel(model_config,train_val_test=train_val_test,dataset=dataset)else:raise ValueError('Invalid model_name')#avod/core下面.下接trainer.train部分trainer.train(model, train_config)#程序開始的地方 def main(_):parser = argparse.ArgumentParser()# Defaults#訓練設置.# split() 通過指定分隔符對字符串進行切片,如果參數 num 有指定值,則僅分隔 num 個子字符串default_pipeline_config_path = avod.root_dir() + \'/configs/avod_cars_example.config'default_data_split = 'train'default_device = '1'#這是一些可以終端設置的地方,在訓練是你需要在終端指定,如果不進行指定,相應內容就會直接選擇默認的設置parser.add_argument('--pipeline_config',type=str,dest='pipeline_config_path',default=default_pipeline_config_path,help='Path to the pipeline config')parser.add_argument('--data_split',type=str,dest='data_split',default=default_data_split,help='Data split for training')parser.add_argument('--device',type=str,dest='device',default=default_device,help='CUDA device id')args = parser.parse_args()# Parse pipeline config#avod_cars_example.config#上面一個的訓練效果不是很好.用pyramid_cars_with_aug_examplemodel_config, train_config, _, dataset_config = \config_builder.get_configs_from_pipeline_file(args.pipeline_config_path, is_training=True)# Overwrite data 
split#train/valdataset_config.data_split = args.data_split# Set CUDA device idos.environ['CUDA_VISIBLE_DEVICES'] = args.devicetrain(model_config, train_config, dataset_config)if __name__ == '__main__':tf.app.run()2.trainer.py 是整個訓練的真正開始部分.
#coding=utf-8 """Detection model trainer.This file provides a generic training method to train a DetectionModel. """ import datetime import os import tensorflow as tf import timefrom avod.builders import optimizer_builder from avod.core import trainer_utils from avod.core import summary_utilsslim = tf.contrib.slimdef train(model, train_config):"""Training function for detection models.Args:model: The detection model object.train_config: a train_*pb2 protobuf.training i.e. loading RPN weights onto AVOD model."""model = modeltrain_config = train_config# Get model configurationsmodel_config = model.model_config# Create a variable tensor to hold the global step#創建變量張量以保持全局步驟創建變量張量以保持全局步驟global_step_tensor = tf.Variable(0, trainable=False, name='global_step')############################## Get training configurations##############################120000max_iterations = train_config.max_iterations#10summary_interval = train_config.summary_interval#1000checkpoint_interval = \train_config.checkpoint_interval#10000max_checkpoints = train_config.max_checkpoints_to_keep#data/output/avod_cars_examplepaths_config = model_config.paths_config#記錄log的文件logdir = paths_config.logdirif not os.path.exists(logdir):os.makedirs(logdir)checkpoint_dir = paths_config.checkpoint_dirif not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)checkpoint_path = checkpoint_dir + '/' + \model_config.checkpoint_nameglobal_summaries = set([])# The model should return a dictionary of predictions#avod_model/build'''start'''#這是prediction部分.其中主要過程是先直接進入avod_model.py,再#avod_model.py的build部分就有來自rpn_model的預測輸入.prediction_dict = model.build()#false!/訓練時我發現是設置為falsesummary_histograms = train_config.summary_histogramssummary_img_images = train_config.summary_img_imagessummary_bev_images = train_config.summary_bev_images############################### Setup loss##############################losses_dict, total_loss = model.loss(prediction_dict)# Optimizer# adam_optimizertraining_optimizer = 
optimizer_builder.build(train_config.optimizer,global_summaries,global_step_tensor)# Create the train opwith tf.variable_scope('train_op'):train_op = slim.learning.create_train_op(total_loss,training_optimizer,clip_gradient_norm=1.0,global_step=global_step_tensor)# Save checkpoints regularly.saver = tf.train.Saver(max_to_keep=max_checkpoints,pad_step_number=True)# Add the result of the train_op to the summarytf.summary.scalar("training_loss", train_op)# Add maximum memory usage summary op# This op can only be run on device with gpu# so it's skipped on travisis_travis = 'TRAVIS' in os.environif not is_travis:# tf.summary.scalar('bytes_in_use',# tf.contrib.memory_stats.BytesInUse())tf.summary.scalar('max_bytes',tf.contrib.memory_stats.MaxBytesInUse())summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))summary_merged = summary_utils.summaries_to_keep(summaries,global_summaries,histograms=summary_histograms,input_imgs=summary_img_images,input_bevs=summary_bev_images)#true!allow_gpu_mem_growth = train_config.allow_gpu_mem_growthif allow_gpu_mem_growth:# GPU memory configconfig = tf.ConfigProto()config.gpu_options.allow_growth = allow_gpu_mem_growthsess = tf.Session(config=config)else:sess = tf.Session()# Create unique folder name using datetime for summary writerdatetime_str = str(datetime.datetime.now())logdir = logdir + '/train'train_writer = tf.summary.FileWriter(logdir + '/' + datetime_str,sess.graph)# Create init opinit = tf.global_variables_initializer()# Continue from last saved checkpoint#true!if not train_config.overwrite_checkpoints:trainer_utils.load_checkpoints(checkpoint_dir,saver)if len(saver.last_checkpoints) > 0:checkpoint_to_restore = saver.last_checkpoints[-1]saver.restore(sess, checkpoint_to_restore)else:# Initialize the variablessess.run(init)else:# Initialize the variablessess.run(init)# Read the global step if restoredglobal_step = tf.train.global_step(sess,global_step_tensor)print('Starting from step {} / {}'.format(global_step, 
max_iterations))# Main Training Looplast_time = time.time()for step in range(global_step, max_iterations + 1):# Save checkpoint#1000if step % checkpoint_interval == 0:global_step = tf.train.global_step(sess,global_step_tensor)saver.save(sess,save_path=checkpoint_path,global_step=global_step)print('Step {} / {}, Checkpoint saved to {}-{:08d}'.format(step, max_iterations,checkpoint_path, global_step))# Create feed_dict for inferencing#輸入feed_dict = model.create_feed_dict()# Write summaries and train op#10if step % summary_interval == 0:current_time = time.time()time_elapsed = current_time - last_timelast_time = current_time#預測部分的整個開頭train_op_loss, summary_out = sess.run([train_op, summary_merged], feed_dict=feed_dict)print('Step {}, Total Loss {:0.3f}, Time Elapsed {:0.3f} s'.format(step, train_op_loss, time_elapsed))train_writer.add_summary(summary_out, step)else:# Run the train op onlysess.run(train_op, feed_dict)# Close the summary writerstrain_writer.close()開篇主要是介紹一下整個程序的開頭部分,不會涉及大量的代碼解釋,其間的大部分的內容都是代碼原部分的英語注釋,比較好懂.接下來會根據網絡的結構分篇介紹!
ps:轉載請注明出處!不必問我,注明出處就行!請尊重別人的勞動成果,雖然只是打打字,但還是很麻煩的.
總結
以上是生活随笔為你收集整理的AVOD-代码阅读理解系列(一)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: python实现黑客帝国动画效果
- 下一篇: 郭天祥的10天学会51单片机_第一节