7.2 TensorFlow Notes (Basics): Generating TFRecords Files
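Preface
Of the three data-reading methods described on the official website, the best one for training models in TensorFlow is reading files through a queue. It takes two steps:
1. Write the sample data into a TFRecords binary file
2. Read from the queue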
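The queue-based reading in step 2 follows a producer/consumer pattern: one thread enqueues work while another dequeues it. The sketch below illustrates only that idea with Python's standard library. It is not TensorFlow's queue API, and the function names are hypothetical.

```python
import queue
import threading


def producer(filenames, q):
    """Enqueue each filename, then a sentinel marking the end of input."""
    for name in filenames:
        q.put(name)  # blocks while the bounded queue is full
    q.put(None)      # sentinel: no more work


def consume_all(q):
    """Dequeue items until the sentinel arrives; return them in order."""
    seen = []
    while True:
        item = q.get()
        if item is None:
            break
        seen.append(item)
    return seen


def run_pipeline(filenames, capacity=2):
    """Run one producer thread and consume its output on the caller's thread."""
    q = queue.Queue(maxsize=capacity)  # bounded, so the producer cannot run far ahead
    worker = threading.Thread(target=producer, args=(filenames, q))
    worker.start()
    result = consume_all(q)
    worker.join()
    return result
```

The bounded queue is the point of the design: the reader stays only a few items ahead of the consumer, so memory use stays flat no matter how much data is on disk.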
TFRecords binary files make better use of memory, are easier to move and copy, and do not need a separate label file.
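To make "binary file" concrete: as far as the TensorFlow documentation describes it, a TFRecords file is a sequence of framed records, each holding a little-endian uint64 length, a masked CRC-32C of the length, the payload bytes, and a masked CRC-32C of the payload. The pure-Python sketch below frames and parses records under that assumption. The helper names are hypothetical, and in practice tf.python_io.TFRecordWriter does this work.

```python
import struct


def _make_table():
    """Build the lookup table for CRC-32C (Castagnoli, reflected polynomial)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _make_table()


def crc32c(data):
    """Plain CRC-32C of a bytes object."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """CRC-32C rotated right by 15 bits plus a constant, in 32-bit arithmetic."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(payload):
    """Wrap one serialized payload in the length/CRC framing."""
    header = struct.pack('<Q', len(payload))
    return (header + struct.pack('<I', masked_crc(header))
            + payload + struct.pack('<I', masked_crc(payload)))


def iter_records(blob):
    """Yield each payload from a bytes blob of framed records, checking CRCs."""
    offset = 0
    while offset < len(blob):
        header = blob[offset:offset + 8]
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', blob[offset + 8:offset + 12])
        if length_crc != masked_crc(header):
            raise ValueError('corrupt length field')
        start = offset + 12
        payload = blob[start:start + length]
        (data_crc,) = struct.unpack('<I', blob[start + length:start + length + 4])
        if data_crc != masked_crc(payload):
            raise ValueError('corrupt payload')
        yield payload
        offset = start + length + 4
```

Each record costs 16 bytes of framing on top of its payload, and because every record carries its own length, a reader can stream through the file without loading it whole.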
The code below, provided on the official website, converts the MNIST files. For the full source, see: tensorflow-master\tensorflow\examples\how_tos\reading_data\convert_to_records.py
CODE
Source code and walkthrough
The explanation is mostly in the comments.
import tensorflow as tf
import os
import argparse
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# 1.0 Generate the TFRecords files
from tensorflow.contrib.learn.python.learn.datasets import mnist

FLAGS = None


# Encoding helpers:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(data_set, name):
    """Converts a dataset to tfrecords."""
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples

    if images.shape[0] != num_examples:
        raise ValueError('Images size %d does not match label size %d.' %
                         (images.shape[0], num_examples))
    rows = images.shape[1]   # 28
    cols = images.shape[2]   # 28
    depth = images.shape[3]  # 1: grayscale image, so a single channel

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        # Raw pixel bytes (tobytes() replaces the deprecated tostring())
        image_raw = images[index].tobytes()
        # Fill the protocol buffer: height, width, depth and label are
        # encoded as int64, image_raw as raw bytes
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())  # serialize to a string
    writer.close()


def main(unused_argv):
    # Get the data.
    data_sets = mnist.read_data_sets(FLAGS.directory,
                                     dtype=tf.uint8,
                                     reshape=False,
                                     validation_size=FLAGS.validation_size)

    # Convert to Examples and write the result to TFRecords.
    convert_to(data_sets.train, 'train')
    convert_to(data_sets.validation, 'validation')
    convert_to(data_sets.test, 'test')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory',
        type=str,
        default='MNIST_data/',
        help='Directory to download data files and write the converted result')
    parser.add_argument(
        '--validation_size',
        type=int,
        default=5000,
        help='Number of examples to separate from the training data '
             'for the validation set.')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Run results
Printed output
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Writing MNIST_data/train.tfrecords
Writing MNIST_data/validation.tfrecords
Writing MNIST_data/test.tfrecords

Files
As the output shows, the run writes train.tfrecords, validation.tfrecords, and test.tfrecords into MNIST_data/.
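One detail of convert_to worth spelling out: image_raw is a flat run of bytes with no shape information, which is presumably why height, width, and depth are stored next to it in every record. The sketch below shows the round trip on a tiny image, with nested Python lists standing in for the NumPy array. The helper names are hypothetical and this is not the code the script runs.

```python
def image_to_raw(image):
    """Flatten a rows x cols x depth nested list of 0..255 ints into bytes.

    Values are emitted in row-major order, the same order NumPy uses by
    default when an array is converted to bytes.
    """
    return bytes(value for row in image for pixel in row for value in pixel)


def raw_to_image(raw, rows, cols, depth):
    """Rebuild the nested list from raw bytes plus the stored dimensions."""
    if len(raw) != rows * cols * depth:
        raise ValueError('byte count does not match rows * cols * depth')
    values = iter(raw)
    return [[[next(values) for _ in range(depth)]
             for _ in range(cols)]
            for _ in range(rows)]


# A 2 x 2 single-channel "image".
tiny = [[[0], [255]],
        [[7], [8]]]
raw = image_to_raw(tiny)
restored = raw_to_image(raw, rows=2, cols=2, depth=1)
```

Without the three dimensions, the same four bytes could equally be read back as a 1 x 4 or a 4 x 1 image, so a reader of the file needs them to restore the original layout.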