Pointnet(part_seg)train.py,test.py代码随记
生活随笔
收集整理的這篇文章主要介紹了
Pointnet(part_seg)train.py,test.py代码随记
小編覺得挺不錯的,現(xiàn)在分享給大家,幫大家做個參考.
train.py
我將代碼全部簡化,將關(guān)鍵步驟全部列出
hdf5_data_dir = 數(shù)據(jù)集路徑 #讀取數(shù)據(jù)集的路徑創(chuàng)建os.mkdir(train_result) #創(chuàng)建train_result文件夾color_map_file = part_color_mapping.json #讀取顏色json文件路徑,一共50類color_map = json.load() #讀取.json文件內(nèi)容讀取overallid_to_catid_partid.json 列表形式 #讀取物體零件編號training_file_list = train_hdf5 路徑 testing_file_list = test_hdf5 路徑model_storage_path = 'trained_models' #在train_result下創(chuàng)建了一個trained_models文件夾,#用于存放訓(xùn)練好的模型 創(chuàng)建logs日志文件夾 創(chuàng)建summaries文件夾,可視化def train():pointclouds_ph = (32,2048,3)input_label_ph = (32,16)label_ph = (32)seg_ph = (32,2048) """以上是導(dǎo)入輸入數(shù)據(jù)的占位符"""batch = 初始化變量0learning_rate = 指數(shù)衰減學(xué)習(xí)率bn_decay = 批標(biāo)準(zhǔn)化衰減率labels_pred , seg_pred, end_points = model.get_model( (32,2048,3),(32,16),...)#模型訓(xùn)練loss = get_loss()#計(jì)算損失train_variables = tf.trainable_variables() #可訓(xùn)練參數(shù)trainer = tf.train.AdamOptimizer(learning_rate) #優(yōu)化器優(yōu)化train_op = trainer.minimize(loss, var_list=train_variables, global_step=batch)#梯度優(yōu)化,更新var_list最大程度減少損失saver = tf.train.Saver() #保存和加載模型init = tf.global_variables_initializer() #全局變量初始化sess.run(init) #圖結(jié)構(gòu)創(chuàng)建好了,開始會話 for epoch in range(training_epoches): #訓(xùn)練次數(shù)eval_one_epoch(epoch)train_file_idx = np.arange(0,6)打亂順序train_one_epoch(train_one_idx , epoch)if(epoch+1) %10 == 0:cp_filename = saver.save(sess , 保存路徑) #保存訓(xùn)練模型def eval_one_epoch(epoch_num):total_label_acc_per_cat = np.zeros[16] #每一類物體分類標(biāo)簽的正確數(shù)total_seg_acc_per_cat = np.zeros[16] #每一類分割正確數(shù)total_seen_per_cat = np.zeros[16] #每類個數(shù)for i in range(num_test_file):cur_data = (2048,2048,3) #測試集的點(diǎn)云數(shù)據(jù)cur_labels =(2048,16) #點(diǎn)云數(shù)據(jù)物體對應(yīng)的16類cur_seg = (2048,2048) #每個點(diǎn)對應(yīng)的50類其中之一cur_labels_one_hot = convert_label_to_one_hot(cur_labels)"""將label都換為one_hot形式"""for j in range(num_batch): #按批次運(yùn)行beginidx-----endidx #開始到結(jié)束的索引loss = sess.run()per_instance_part_acc = np.mean(pred_seg_res == cur_seg[begidx:endidx,...],axis=1 )"""求每個物體的零件正確率"""average_part_acc = np.mean(per_instance_part_acc)"""求這32個物體的平均零件正確率"""per_instance_label_pred = np.argmax(label_pred_val, axis=1)""" 求出這32個物體對類別預(yù)測的標(biāo)簽 """total_label_acc += np.mean(np.float32(per_instance_label_pred == cur_labels[begidx: endidx, ...]))""" 算出預(yù)測標(biāo)簽的正確率并求平均進(jìn)行累加"""total_seg_acc += average_part_acc """將平均零件分割正確率累加"""for shape_idx in range(begidx, endidx):total_seen_per_cat[cur_labels[shape_idx]] += 1"""test過的每一類的個數(shù)"""total_label_acc_per_cat[cur_labels[shape_idx]]+=np.int32(per_instance_label_pred[shape_idx-begidx] == cur_labels[shape_idx])"""每一類標(biāo)簽判斷正確的個數(shù):預(yù)測標(biāo)簽與正確標(biāo)簽對比,如果正確就在相應(yīng)位置+1"""total_seg_acc_per_cat[cur_labels[shape_idx]] += per_instance_part_acc[shape_idx - begidx]"""將每個物體分割的正確率累加"""total_loss = total_loss * 1.0 / total_seentotal_label_loss = total_label_loss * 1.0 / total_seentotal_seg_loss = total_seg_loss * 1.0 / total_seentotal_label_acc = total_label_acc * 1.0 / total_seentotal_seg_acc = total_seg_acc * 1.0 / total_seentrain_one_proch比eval_one_peoch 多一個優(yōu)化器過程
test.py
自行定義命令參數(shù)獲取 model_path ,保存的訓(xùn)練模型pretrained_model_path = FLAGS.model_path #獲取保存好的模型 hdf5_data_dir = './hdf5_data' # 獲取h5數(shù)據(jù)集 ply_data_dir = './PartAnnotation' # 導(dǎo)入測試數(shù)據(jù)集test_file_list = os.path.join(BASE_DIR, 'testing_ply_file_list.txt') """ testing_ply_file_list.txt為從PartAnnotation數(shù)據(jù)集中采樣出的2874個數(shù)據(jù),分別包括點(diǎn)云數(shù)據(jù) / 分割數(shù)據(jù) / 實(shí)例類別編號"""oid2cpid = 'overallid_to_catid_partid.json' """oid2cpid讀取物體零件編號[["02691156", 1], ["02691156", 2],....]""" object2setofoid = {} #oid對象集 for idx in range(len(oid2cpid)):objid, pid = oid2cpid[idx] #objid對象標(biāo)識符 pid編號if not objid in object2setofoid.keys():object2setofoid[objid] = []object2setofoid[objid].append(idx) """創(chuàng)建一個字典,將每個物體編號按順序0~49索引排序{'02691156':[0,1,2,3],'02773838':[4,5],.....}"""all_obj_cat_file = 'all_object_categories.txt' 獲取16類物體和編號的文件,并分別劃分到兩個列表中 objcats = split()[0] """['02691156','02773838',......]""" objnames = split()[1] """['Airplane','Bag',......]"""color_map = json.load('part_color_mapping.json') 獲取顏色cpid2oid = 'catid_partid_to_overallid.json' """cpid2oid為對物體零件進(jìn)行分類1~50類對應(yīng){"03642806_2": 29, "03642806_1": 28,...."""------------------------------------數(shù)據(jù)集的前期處理全部完成----------------------------------- def predict():pointclouds_ph = (1,3000,3)input_label_ph = (1,16)pred , seg_pred , end_points = get_model(pointclouds_ph, input_label_ph,...)"""模型占位符"""saver = tf.train.Saver()"""添加操作用來保存和重現(xiàn)所有變量"""with tf.Seesion(config=config) as sess:saver.restore(sess, pretrained_model_path)"""導(dǎo)入訓(xùn)練好的模型"""batch_data = np.zeros[1,3000,3]total_per_cat_acc = np.zeros(16)"""每一類正確的個數(shù)"""total_per_cat_iou = np.zeros(16)""" 每一類的IOU"""total_per_cat_seen = np.zeros(16)""" 每一類測試的總個數(shù)"""獲取測試用的數(shù)據(jù)集test_file_list,并進(jìn)行預(yù)處理,將其劃分為3類列表pts_files = split()[0] """獲取的點(diǎn)云文件路徑"""seg_files = split()[1]"""獲取seg文件路徑"""labels = split()[2]""" 獲取物體類別編號""""""開始逐個對測試數(shù)據(jù)集中的數(shù)據(jù)進(jìn)行操作,測試數(shù)據(jù)有2874個"""for shape_idx in range(len_pts_files):cur_gt_label = on2oid[labels[shape_idx]]""" on2oid為物體編號對應(yīng)索引,總共有16個,獲取當(dāng)前數(shù)據(jù)集的編號對應(yīng)索引"""將其轉(zhuǎn)換為獨(dú)熱編碼pts_file_to_load = os.path.join(ply_data_dir, pts_files[shape_idx])seg_file_to_load = os.path.join(ply_data_dir, seg_files[shape_idx])"""根據(jù)shape_idx將pts文件和seg文件讀取出來"""pts, seg = load_pts_seg_files(pts_file_to_load, seg_file_to_load, objcats[cur_gt_label])"""將各物體編號都統(tǒng)一到1~50類當(dāng)中,這個操作非常關(guān)鍵!!!!! """def load_pts_seg_files(pts_file, seg_file, catid):with open(pts_file, 'r') as f:pts_str = [item.rstrip() for item in f.readlines()]pts = np.array([np.float32(s.split()) for s in pts_str], dtype=np.float32)a = len(pts) with open(seg_file, 'r') as f:part_ids = np.array([int(item.rstrip()) for item in f.readlines()], dtype=np.uint8)"""在單獨(dú)一個物體中以1,2,3將不同零件進(jìn)行分類,得出的零件索引[2 2 2 1 1 1 1 1 ....]"""seg = np.array([cpid2oid[catid+'_'+str(x)] for x in part_ids])"""cpid2oid為每個物體零件對應(yīng)的0~50類編號,將單個物體零件的分類通過cpid2oid轉(zhuǎn)換為總的50類別"""label_pred_val , seg_pred_res = sess.run()""" 預(yù)測出的label 和 seg"""label_pred_val = np.argmax(label_pred_val[0, :]) """將預(yù)測出的label得出"""seg_pred_res = seg_pred_res[0,....] #進(jìn)行降維處理c = seg_pred_res.shaoe #(3000,50)iou_oids = object2setofoid[objcats[cur_gt_label]]""" 將該物體的零件索引提取出來objacts:['02691156','02773838',......]object2setofoid:{'02691156':[0,1,2,3],'02773838':[4,5],.....}[12,13,14,15]"""non_cat_labels = list(set(np.arange(NUM_PART_CATS)).difference(set(iou_oids))) """創(chuàng)建一個0~49的數(shù)組,剔除12,13,14,15"""mini = np.min(seg_pred_res) #獲取預(yù)測中的最小值seg_pred_res[:, non_cat_labels] = mini - 1000 #將除12,13,14,15的其他標(biāo)簽都減小seg_pred_val = np.argmax(seg_pred_res, axis=1)[:ori_point_num]"""比較12,13,14,15這個位置的數(shù),取最大判斷為該類"""seg_acc = np.mean(seg_pred_val == seg)"""預(yù)測的類與正確實(shí)際的類做比較,得出seg的正確率"""total_acc += seg_acc"""將分割的正確率進(jìn)行累加"""total_seen += 1""" 測試總的個數(shù)"""total_per_cat_seen[cur_gt_label] += 1total_per_cat_acc[cur_gt_label] += seg_accmask = np.int32(seg_pred_val == seg)"""預(yù)測類與正確的比較,相等為1,不等為0""" 計(jì)算IOU = n_intersect/(n_pred + n_seg - n_intersect)n_pred = 預(yù)測的12標(biāo)簽的個數(shù)n_seg = 實(shí)際的12標(biāo)簽的個數(shù)n_intersect = 判斷正確的12標(biāo)簽的個數(shù)"""對預(yù)測結(jié)果,保存在obj文件"""if output_verbose:output_color_point_cloud(pts, seg, os.path.join(output_dir, str(shape_idx)+'_gt.obj'))output_color_point_cloud(pts, seg_pred_val, os.path.join(output_dir, str(shape_idx)+'_pred.obj'))output_color_point_cloud_red_blue(pts, np.int32(seg == seg_pred_val), os.path.join(output_dir, str(shape_idx)+'_diff.obj'))總結(jié)
以上是生活随笔為你收集整理的Pointnet(part_seg)train.py,test.py代码随记的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 360图书馆自动全文.
- 下一篇: 半桥BUCK电路—记录篇