分类任务4——用dataset读取tfrecord
這里用了matplotlib來驗證是否成功生成dataset。
第一個需要注意的是,在decode_example()中用了一個tf.reshape,這個函數里的參數應該和前面制作的tfrecord的圖像參數一致,也就是如果前面tfrecord的圖像大小是600,那么此時的tf.reshape里應該是[600, 600, 3];而如果是224,那就是[224, 224, 3]
還需要注意的是在用dataset讀取的時候,如果最后一個batch的大小和我們自行定義的batch_size不符合,就會引起tf.errors.OutOfRangeError,這個時候需要用異常處理try, except,代碼示意為:
try:_, _ = sess.run([_, _]) # _所代表的內容根據自身的實驗添加 except tf.errors.OutOfRangeError:print('it is the end')break此外還需要注意的是,在iterator的選擇上:
如果是用于訓練和驗證,應該選擇iterator = dataset.make_initializable_iterator()。因為使用這個可以循環輸入,保持一個train.batch對應一個validation_batch。
如果是用于測試,可以選擇上面的選擇iterator = dataset.make_one_shot_iterator()。因為只需要一次遍歷。
最后需要注意的是,iteration和epoch之間的關系。如果在代碼中加入while True:,那么上面的i代表了epoch,即一個i期間完成了一次完整的數據集遍歷;如果 沒有 在代碼中加入while True:,那么上面的i代表了iteration,即一個i期間完成了一次完整的batch_size遍歷。
平時常說的訓練幾萬次就是指平時常說的訓練幾萬次就是指平時常說的訓練幾萬次就是指幾萬次 iteration。
附上完整的代碼:
"""# 使用dataset讀取tfrecord""" import tensorflow as tf import matplotlib.pyplot as plt# 單個record的解析函數 def decode_example(example, resize_height, resize_width, label_nums):dics = {'image_raw': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64)}parsed_example = tf.parse_single_example(serialized=example, features=dics)tf_image = tf.decode_raw(parsed_example['image_raw'], out_type=tf.uint8) # 這個其實就是圖像的像素模式,之前我們使用矩陣來表示圖像# 對圖像的尺寸進行調整,調整成三通道圖像,如果前面制作tfrecord時的圖像大小是多少,這里就要寫多少。# 比如tfrecord的圖像是224,這里就寫[224,224,3];如果是600,就是[600, 600, 3]tf_image = tf.reshape(tf_image, shape=[600, 600, 3])tf_image = tf.image.resize_images(tf_image, (resize_height, resize_width), method=2)tf_image = tf.cast(tf_image, tf.float32) * (1. / 255) # 對圖像進行歸一化以便保持和原圖像有相同的精度tf_label = tf.cast(parsed_example['label'], tf.int64)tf_label = tf.one_hot(tf_label, label_nums, on_value=1, off_value=0) # 將label轉化成用one_hot編碼的格式return tf_image, tf_label# 生成dataset,可以引入其他py文件中。 def create_dataset(tfrecords_file, batch_size, resize_height, resize_width, num_class):dataset = tf.data.TFRecordDataset(tfrecords_file)# map函數可以讓dataset的結構和上面的decode_example一致。# lambda個人推測是為了構建上面的decode_example,但是如果decode函數除了example之外沒有其他變量,那就不需要用lambda# 但是這里的lambda里只有一個x,對應的example,而且這個x不需要外部函數傳值,這一點很困惑。# 我個人覺得是在tf.data.TFRecordDataset創建了dataset之后,dataset里就包含了serialized對應的值# 因此可以函數內部自己傳值,不需要外部設置變量再傳值。dataset = dataset.map(lambda x: decode_example(x, resize_height, resize_width, num_class))dataset = dataset.shuffle(2000)dataset = dataset.batch(batch_size)return datasetif __name__ == '__main__':# 定義可以一次獲得多張圖像的函數def show_image(image_dir):plt.imshow(image_dir)plt.axis('on')plt.show()# 檢查dataset是否正確生成def batch_test(tfrecords_file, batch_size, resize_height, resize_width, num_class):dataset = create_dataset(tfrecords_file, batch_size, resize_height, resize_width, num_class)# iterator = dataset.make_one_shot_iterator()# iterator = tf.data.Iterator.from_structure(dataset.output_types,# dataset.output_shapes)# iterator1 = iterator.make_initializer(dataset)iterator = dataset.make_initializable_iterator()batch_images, batch_labels = iterator.get_next()init_op = tf.global_variables_initializer()with tf.Session() as sess:sess.run(init_op)coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)# 查看tfrecord的樣本數量sample_nums = 0for record in tf.python_io.tf_record_iterator(tfrecords_file):sample_nums += 1print('\n#######\n')print('this tfrecord file: "{}" has _{}_ samples'.format(tfrecords_file, sample_nums))print('\n#######\n')for i in range(100):# sess.run(iterator1)sess.run(iterator.initializer)# 如果加上while True,那么上面的i代表著一個epoch,也就是將數據集的全部樣本都完成一遍;# 如果不加上while True,那么上面的i就代表著一個iteration,也就是使用一個batch_size完成一次。# iteration和epoch之間的關系為:1 epoch (所代表的數量)= batch_size * iteration# 在這里我選擇的是iteration,如果需要epoch可以在添加while True之后再將try-break之間的代碼tab一下就行。try:# show_image(images[5,:,:,:]) # 代表每一個batch的第三張圖片images, labels = sess.run([batch_images, batch_labels])print('{}th, image.shape:{}, type:{}, labels.shape:{}'.format(i + 1, images.shape, images.dtype,labels.shape))except tf.errors.OutOfRangeError:print('\n******\n')print('{}th batch is the final batch, total iteration is: {} '.format(i, i))print('\n******\n')breakcoord.request_stop()coord.join(threads)tfrecords_file = 'E:/111project/tfrecord/validation.tfrecords'resize_height = 100resize_width = 100num_class = 5batch_test(tfrecords_file, 200, resize_height, resize_width, num_class)總結
以上是生活随笔為你收集整理的分类任务4——用dataset读取tfrecord的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: HBuilder快捷键整理集合
- 下一篇: VMware Workstation P