使用object detection训练并识别自己的模型
1.安裝tensorflow(version>=1.4.0)
2.部署tensorflow models
-?在這里下載
- 解壓并安裝
- 解壓后重命名為models復制到tensorflow/目錄下
- 在linux下
- 進入tensorflow/models/research/目錄,運行protoc object_detection/protos/*.proto --python_out=.
- 在~/.bashrc file.中添加slim和models/research路徑
export PYTHONPATH=$PYTHONPATH:/path/to/slim:/path/to/research
- 在windows下
- 下載protoc-3.3.0-win32.zip(version==3.3,已知3.5版本會報錯)?
- 解壓后將protoc.exe放入C:\Windows下
- 在tensorflow/models/research/打開powershell,運行protoc object_detection/protos/*.proto --python_out=.
3.訓練數據準備(標記分類的圖片)
- 安裝labelImg?用來手動標注圖片 ,圖片需要是png或者jpg格式
- 標注信息會被保存為xml文件,使用?這個腳本?將所有xml文件轉換為一個csv文件(xml文件路徑識別在29行,根據情況自己修改)
- 把生成的csv文件分成訓練集和測試集
4.生成TFRecord文件
- 使用?這個腳本?將兩個csv文件生成出兩個TFRecord文件(訓練自己的模型,必須使用TFRecord格式文件。圖片路徑識別在86行,根據情況自己修改)
5.創建label map文件
id需要從1開始,class-N便是自己需要識別的物體類別名,文件后綴為.pbtxt
item{
id:1
name: 'class-1'
}
item{
id:2
name: 'class-2'
}
6.下載模型并配置文件
- 下載一個模型(文件后綴.tar.gz)
- 修改對應的訓練pipline配置文件?
- 查找文件中的PATH_TO_BE_CONFIGURED字段,并做相應修改
- num_classes 改為你模型中包含類別的數量
- fine_tune_checkpoint 解壓.tar.gz文件后的路徑 + /model.ckpt
- from_detection_checkpoint:true
- train_input_reader
- input_path 由train.csv生成的record格式訓練數據
- label_map_path 第5步創建的pbtxt文件路徑
- eval_input_reader
- input_path 由test.csv生成的record格式訓練數據
- label_map_path 第5步創建的pbtxt文件路徑
7. 訓練模型
- 進入tensorflow/models/research/目錄,運行
python object_detection/train.py?--logtostderr? --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG}?//第六步中修改的pipline配置文件路徑//?--train_dir=${PATH_TO_TRAIN_DIR}?//生成的模型保存路徑//
8.導出模型
- 在第7步中,--train_dir指向的路徑中會生成一系列訓練中自動保存的checkpoint,一個checkpoint由三個文件組成,后綴分別是.data-00000-of-00001 .index和.meta,任然在第7步的路徑中,運行
python object_detection/export_inference_graph.py \
--input_type image_tensor? \
--pipeline_config_path?${PIPELINE_CONFIG_PATH}?//第六步中修改的pipline配置文件路徑\--trained_checkpoint_prefix?${TRAIN_PATH}?//上述的一個checkpoint,例如model.ckpt-112254 \ --output_directory?${OUTPUT_PATH}?//輸出模型文件的路徑//
9.使用新模型識別圖片
調用predict.py
首先導入包
import time import cv2 import numpy as np import tensorflow as tf import pandas as pd import math import osfrom object_detection.utils import label_map_util from object_detection.utils import visualization_utils as vis_util然后定義類和函數
class TOD(object):def __init__(self):self.PATH_TO_CKPT = r'D:/xiangchuang/new_train_model/result/frozen_inference_graph.pb'self.PATH_TO_LABELS = r'D:/xiangchuang/pig.pbtxt'self.NUM_CLASSES = 1self.detection_graph = self._load_model()self.category_index = self._load_label_map()def _load_model(self):global detection_graphdetection_graph = tf.Graph()with detection_graph.as_default():od_graph_def = tf.GraphDef()with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:serialized_graph = fid.read()od_graph_def.ParseFromString(serialized_graph)tf.import_graph_def(od_graph_def, name='')return detection_graphdef _load_label_map(self):label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)categories = label_map_util.convert_label_map_to_categories(label_map,max_num_classes=self.NUM_CLASSES,use_display_name=True)category_index = label_map_util.create_category_index(categories)return category_indexdef detect(self, image):image_np_expanded = np.expand_dims(image, axis=0)image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')scores = self.detection_graph.get_tensor_by_name('detection_scores:0')classes = self.detection_graph.get_tensor_by_name('detection_classes:0')num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')# Actual detection.(boxes, scores, classes, num_detections) = sess.run([boxes, scores, classes, num_detections],feed_dict={image_tensor: image_np_expanded})# Visualization of the results of a detection.vis_util.visualize_boxes_and_labels_on_image_array(image,np.squeeze(boxes),np.squeeze(classes).astype(np.int32),np.squeeze(scores),self.category_index,use_normalized_coordinates=True,line_thickness=8)cv2.namedWindow("detection", cv2.WINDOW_NORMAL)cv2.imshow("detection", image)cv2.waitKey(1)最后執行
if __name__ == '__main__':detector = TOD()with detection_graph.as_default():with tf.Session(graph=detection_graph) as sess:cap = cv2.VideoCapture(r'Your Vedio Path')n = 1success = Truewhile (success) :success, frame = cap.read()t1=time.clock()print('正在預測第%d張' % n)n = n + 1if success == True:detector.detect(frame)t2=time.clock()t = t2-t1print('cost time %f s'%t)cv2.destroyAllWindows()即可以實現基于視頻的目標目標檢測
?
參考文檔
https://gist.github.com/douglasrizzo/c70e186678f126f1b9005ca83d8bd2ce
https://towardsdatascience.com/how-to-train-your-own-object-detector-with-tensorflows-object-detector-api-bec72ecfe1d9
總結
以上是生活随笔為你收集整理的使用object detection训练并识别自己的模型的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 深度学习中,Batch_Normaliz
- 下一篇: ncnn源码编译安装