yolov3-tf2 数据格式压缩
生活随笔
收集整理的這篇文章主要介紹了
yolov3-tf2 数据格式压缩
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
對於 VOC 這樣的圖像數據集,數據佔用空間比較大。之前一般先把圖像解碼成矩陣形式放在內存中進行模型訓練,這需要大量內存空間。
TensorFlow 提供 tf.io.TFRecordWriter 數據 API,可以把圖像原始的 JPEG 編碼字節連同標註一起序列化寫入 TFRecord 文件。由於保存的是已壓縮的圖像而不是解碼後的像素矩陣,佔用的空間要小得多,代碼如下:
import hashlib
import os
import time

import lxml.etree
import tensorflow as tf
import tqdm
from absl import app, flags, logging
from absl.flags import FLAGS

# Default locations and split; these match the paths the script always used.
DEFAULT_DATA_DIR = './data/voc2012_raw/VOCdevkit/VOC2012'
DEFAULT_SPLIT = 'train'
DEFAULT_CLASSES_FILE = './data/voc2012.names'


def _int64_feature(values):
    """Wrap a list of ints as a ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(values):
    """Wrap a list of floats as a ``tf.train.Feature``."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def _bytes_feature(values):
    """Wrap a list of bytes objects as a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def build_example(annotation, class_map, data_dir=DEFAULT_DATA_DIR):
    """Build a ``tf.train.Example`` for one VOC image and its annotation.

    Args:
        annotation: the ``annotation`` dict produced by ``parse_xml`` for one
            VOC XML file. It must contain ``filename`` and ``size`` (with
            ``width``/``height``); it may contain a list under ``object``.
        class_map: mapping from class name to integer label.
        data_dir: VOC root directory; the image is read from
            ``<data_dir>/JPEGImages/<filename>``.

    Returns:
        A ``tf.train.Example`` holding the raw (still JPEG-encoded) image
        bytes, their SHA-256 hex digest, and per-object boxes whose x/y
        coordinates are divided by the image width/height respectively.

    Raises:
        KeyError: if an object's class name is not in ``class_map``.
    """
    img_path = os.path.join(data_dir, 'JPEGImages', annotation['filename'])
    with open(img_path, 'rb') as img_file:
        img_raw = img_file.read()
    key = hashlib.sha256(img_raw).hexdigest()

    width = int(annotation['size']['width'])
    height = int(annotation['size']['height'])

    xmin, ymin, xmax, ymax = [], [], [], []
    classes, classes_text = [], []
    truncated, views, difficult_obj = [], [], []
    for obj in annotation.get('object', []):
        # Collapse any non-zero 'difficult' value to 1.
        difficult_obj.append(int(bool(int(obj['difficult']))))
        bndbox = obj['bndbox']
        xmin.append(float(bndbox['xmin']) / width)
        ymin.append(float(bndbox['ymin']) / height)
        xmax.append(float(bndbox['xmax']) / width)
        ymax.append(float(bndbox['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(class_map[obj['name']])
        truncated.append(int(obj['truncated']))
        views.append(obj['pose'].encode('utf8'))

    filename = annotation['filename'].encode('utf8')
    feature = {
        'image/height': _int64_feature([height]),
        'image/width': _int64_feature([width]),
        'image/filename': _bytes_feature([filename]),
        'image/source_id': _bytes_feature([filename]),
        'image/key/sha256': _bytes_feature([key.encode('utf8')]),
        'image/encoded': _bytes_feature([img_raw]),
        'image/format': _bytes_feature(['jpeg'.encode('utf8')]),
        'image/object/bbox/xmin': _float_feature(xmin),
        'image/object/bbox/xmax': _float_feature(xmax),
        'image/object/bbox/ymin': _float_feature(ymin),
        'image/object/bbox/ymax': _float_feature(ymax),
        'image/object/class/text': _bytes_feature(classes_text),
        'image/object/class/label': _int64_feature(classes),
        'image/object/difficult': _int64_feature(difficult_obj),
        'image/object/truncated': _int64_feature(truncated),
        'image/object/view': _bytes_feature(views),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def parse_xml(xml):
    """Recursively convert an lxml element into nested dicts.

    A leaf element becomes ``{tag: text}``. For an element with children,
    each ``<object>`` child is appended to a list (an annotation can hold
    several objects); for any other repeated tag the last occurrence wins.
    XML comments and processing instructions are skipped.

    Returns:
        ``{xml.tag: <text or dict>}``.
    """
    if not len(xml):
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        # lxml gives comments / processing instructions a non-str tag.
        if not isinstance(child.tag, str):
            continue
        child_result = parse_xml(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            result.setdefault(child.tag, []).append(child_result[child.tag])
    return {xml.tag: result}


def main(_argv=None, *, data_dir=DEFAULT_DATA_DIR, split=DEFAULT_SPLIT,
         output_file=None, classes_file=DEFAULT_CLASSES_FILE):
    """Convert one PASCAL VOC split into a TFRecord file.

    Args:
        _argv: ignored; accepted so ``absl.app.run(main)`` can call this.
        data_dir: VOC root directory (contains ``ImageSets``, ``Annotations``
            and ``JPEGImages``).
        split: image-set name, read from ``ImageSets/Main/<split>.txt``
            (e.g. ``'train'`` or ``'val'``).
        output_file: TFRecord path to write. Defaults to
            ``./data/voc2012_<split>.tfrecord``.
        classes_file: text file with one class name per line; a class's
            label is its 0-based line index.
    """
    # Map each class name to its line index in the names file.
    with open(classes_file) as names_file:
        class_map = {name: idx
                     for idx, name in enumerate(names_file.read().splitlines())}

    if output_file is None:
        output_file = './data/voc2012_%s.tfrecord' % split

    list_path = os.path.join(data_dir, 'ImageSets', 'Main', '%s.txt' % split)
    with open(list_path) as list_file:
        image_list = list_file.read().splitlines()
    logging.info("Image list loaded: %d", len(image_list))

    # Store each image's original (JPEG-encoded) bytes plus its annotations
    # in a TFRecord instead of decoded pixel matrices; the encoded form takes
    # much less space. The context manager closes the writer even on error.
    with tf.io.TFRecordWriter(output_file) as writer:
        for name in tqdm.tqdm(image_list):
            annotation_path = os.path.join(data_dir, 'Annotations', name + '.xml')
            # Read bytes: lxml rejects str input with an XML encoding declaration.
            with open(annotation_path, 'rb') as xml_file:
                annotation_xml = lxml.etree.fromstring(xml_file.read())
            annotation = parse_xml(annotation_xml)['annotation']
            tf_example = build_example(annotation, class_map, data_dir)
            writer.write(tf_example.SerializeToString())
    logging.info("Done")


if __name__ == '__main__':
    main()

# 總結
以上是生活随笔為你收集整理的yolov3-tf2 数据格式压缩的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: tf.lookup.StaticHash
- 下一篇: yolov3 -tf 解析数据