tf.data.Dataset与tfrecord学习笔记
目錄
1.tf.data.Dataset
2.tfrecord
2.1 使用tfrecord的原因
2.2 tfrecord的寫入
2.3 tfrecord的讀取
3.兩種方式的區別
參考資料:
1.tf.data.Dataset
# 從tensor中獲取數據
dataset = tf.data.Dataset.from_tensor_slices(img_paths)
?
# 可選項,從數據集中過濾數據
dataset = dataset.filter(filter)
?
# 數據解析,原來可能是路徑,需要變成真實的圖片數據
# 其中num_parallel_calls表示并行操作的線程數量,一般設置為CPU核心數量為最好
dataset = dataset.map(map_func, num_parallel_calls=num_threads)
?
# 打亂數據,這里有個buffer_size,表示每次從這個buffer_size個數據中隨機一個位置,與buffer_size外的數據進行交換
dataset = dataset.shuffle(buffer_size)
?
# 把數據集組裝成batchs
if drop_remainder:
????# 將dataset切分成n個batch_size,并且決定是否丟掉最后一個不滿足一個batch的數據
????dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))
else:
????dataset = dataset.batch(batch_size)
?
# repeat表示重復的次數,-1表示重復無限次,這樣就永遠不會報outOfRange這種錯
"""
tf.data.Dataset.prefetch 提供了 software pipelining 機制。該函數解耦了 數據產生的時間 和 數據消耗的時間。
具體來說,該函數有一個后臺線程和一個內部緩存區,在數據被請求前,就從 dataset 中預加載一些數據(進一步提高性能)。
prefech(n) 一般作為最后一個 transformation,其中 n 為 batch_size。
因為數據已經分成了好幾個batch,那么這句話其實就是在數據被請求之前,預加載2個batch的數據
"""
dataset = dataset.repeat(repeat).prefetch(prefetch_batch)
?
# 建立一個迭代器
iterator =?dataset.make_initializable_iterator()
batch_data = iterator.get_next()
?
# 建立一個會話Session
with tf.Session as sess:
? ? # 初始化迭代器
? ? sess.run(iterator.initializer)
? ? data = sess.run(batch_data)
實測代碼:
import osimport numpy as np import tensorflow as tf from tflib.utils import session import randomimg_paths = "E:\\python_project\\DeeCamp\\data\\list_attr_celeba.txt" buffer_size = 4096 drop_remainder = True batch_size = 32 repeat = -1 prefetch_batch = 2names = np.loadtxt(img_paths, skiprows=2, usecols=[0], dtype=np.str)print("start") # 從tensor中獲取數據 dataset = tf.data.Dataset.from_tensor_slices(names)print("read files over")# 可選項,從數據集中過濾數據 # dataset = dataset.filter()# 數據解析,原來可能是路徑,需要變成真實的圖片數據 # 其中num_parallel_calls表示并行操作的線程數量,一般設置為CPU核心數量為最好 # dataset = dataset.map(map_func, num_parallel_calls=num_threads)# 打亂數據,這里有個buffer_size,表示每次從這個buffer_size個數據中隨機一個位置,與buffer_size外的數據進行交換 dataset = dataset.shuffle(buffer_size)# 把數據集組裝成batchs if drop_remainder:# 將dataset切分成n個batch_size,并且決定是否丟掉最后一個不滿足一個batch的數據dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(batch_size)) else:dataset = dataset.batch(batch_size)# repeat表示重復的次數,-1表示重復無限次,這樣就永遠不會報outOfRange這種錯 """ tf.data.Dataset.prefetch 提供了 software pipelining 機制。該函數解耦了 數據產生的時間 和 數據消耗的時間。 具體來說,該函數有一個后臺線程和一個內部緩存區,在數據被請求前,就從 dataset 中預加載一些數據(進一步提高性能)。 prefech(n) 一般作為最后一個 transformation,其中 n 為 batch_size。 因為數據已經分成了好幾個batch,那么這句話其實就是在數據被請求之前,預加載2個batch的數據 """ dataset = dataset.repeat(repeat).prefetch(prefetch_batch)# 建立一個迭代器 iterator = dataset.make_initializable_iterator() batch_data = iterator.get_next()# 建立一個會話Session with tf.Session() as sess:# 初始化迭代器sess.run(iterator.initializer)for i in range(10):data = sess.run(batch_data)print(data)print("="*20)結果如下:
[b'002411.jpg' b'001866.jpg' b'003657.jpg' b'000849.jpg' b'002705.jpg'
b'002485.jpg' b'000120.jpg' b'002057.jpg' b'003620.jpg' b'003092.jpg'
b'003111.jpg' b'000557.jpg' b'003030.jpg' b'001831.jpg' b'001967.jpg'
b'000258.jpg' b'002366.jpg' b'004102.jpg' b'000067.jpg' b'002444.jpg'
b'003847.jpg' b'003876.jpg' b'000516.jpg' b'002107.jpg' b'003941.jpg'
b'004006.jpg' b'000632.jpg' b'000080.jpg' b'002286.jpg' b'003046.jpg'
b'000785.jpg' b'001122.jpg']
====================
[b'001681.jpg' b'001067.jpg' b'002432.jpg' b'002173.jpg' b'001478.jpg'
b'000610.jpg' b'001715.jpg' b'002695.jpg' b'004003.jpg' b'004100.jpg'
b'002240.jpg' b'000286.jpg' b'003298.jpg' b'000760.jpg' b'003712.jpg'
b'003076.jpg' b'003598.jpg' b'000423.jpg' b'003211.jpg' b'002405.jpg'
b'001274.jpg' b'003872.jpg' b'004079.jpg' b'000486.jpg' b'004012.jpg'
b'003247.jpg' b'001156.jpg' b'004073.jpg' b'002359.jpg' b'000636.jpg'
b'000349.jpg' b'001392.jpg']
====================
[b'001704.jpg' b'001051.jpg' b'002887.jpg' b'003227.jpg' b'000357.jpg'
b'003706.jpg' b'003297.jpg' b'004016.jpg' b'002112.jpg' b'002975.jpg'
b'004077.jpg' b'002272.jpg' b'001991.jpg' b'000694.jpg' b'001515.jpg'
b'000242.jpg' b'002169.jpg' b'003926.jpg' b'001462.jpg' b'002646.jpg'
b'003214.jpg' b'000487.jpg' b'000326.jpg' b'001344.jpg' b'001069.jpg'
b'003025.jpg' b'002724.jpg' b'002502.jpg' b'002479.jpg' b'004098.jpg'
b'001749.jpg' b'003203.jpg']
====================
[b'003235.jpg' b'003145.jpg' b'000356.jpg' b'003175.jpg' b'001426.jpg'
b'003209.jpg' b'004105.jpg' b'002073.jpg' b'003118.jpg' b'003629.jpg'
b'001634.jpg' b'003120.jpg' b'000098.jpg' b'001096.jpg' b'001607.jpg'
b'003158.jpg' b'004115.jpg' b'000084.jpg' b'003362.jpg' b'003666.jpg'
b'001573.jpg' b'002369.jpg' b'002097.jpg' b'003621.jpg' b'003484.jpg'
b'003809.jpg' b'001107.jpg' b'001207.jpg' b'003556.jpg' b'003763.jpg'
b'003594.jpg' b'001101.jpg']
====================
[b'000073.jpg' b'003798.jpg' b'002839.jpg' b'002614.jpg' b'002921.jpg'
b'002453.jpg' b'003261.jpg' b'002648.jpg' b'002605.jpg' b'003388.jpg'
b'003010.jpg' b'000752.jpg' b'003783.jpg' b'001673.jpg' b'002732.jpg'
b'002936.jpg' b'001997.jpg' b'003518.jpg' b'001005.jpg' b'002789.jpg'
b'001082.jpg' b'003087.jpg' b'003873.jpg' b'003871.jpg' b'001441.jpg'
b'003494.jpg' b'000135.jpg' b'001564.jpg' b'000410.jpg' b'002700.jpg'
b'001258.jpg' b'003723.jpg']
====================
[b'000948.jpg' b'003301.jpg' b'003280.jpg' b'001173.jpg' b'002086.jpg'
b'001553.jpg' b'001125.jpg' b'003796.jpg' b'002469.jpg' b'000866.jpg'
b'003491.jpg' b'003708.jpg' b'004152.jpg' b'001616.jpg' b'003965.jpg'
b'002069.jpg' b'002966.jpg' b'000739.jpg' b'001433.jpg' b'000419.jpg'
b'001955.jpg' b'003578.jpg' b'003493.jpg' b'000992.jpg' b'001333.jpg'
b'004042.jpg' b'003442.jpg' b'001623.jpg' b'003615.jpg' b'004140.jpg'
b'003635.jpg' b'000619.jpg']
====================
[b'002193.jpg' b'002691.jpg' b'000456.jpg' b'002500.jpg' b'001423.jpg'
b'003624.jpg' b'002149.jpg' b'000743.jpg' b'001570.jpg' b'002141.jpg'
b'002891.jpg' b'000467.jpg' b'002985.jpg' b'003384.jpg' b'003971.jpg'
b'003143.jpg' b'001541.jpg' b'003032.jpg' b'002317.jpg' b'003951.jpg'
b'001980.jpg' b'000183.jpg' b'002111.jpg' b'001115.jpg' b'000163.jpg'
b'000381.jpg' b'004301.jpg' b'001529.jpg' b'002506.jpg' b'003976.jpg'
b'003886.jpg' b'002414.jpg']
====================
[b'001126.jpg' b'000007.jpg' b'002410.jpg' b'002568.jpg' b'003724.jpg'
b'002947.jpg' b'003988.jpg' b'004004.jpg' b'002682.jpg' b'003284.jpg'
b'000003.jpg' b'001234.jpg' b'001080.jpg' b'002395.jpg' b'000085.jpg'
b'002064.jpg' b'000646.jpg' b'003652.jpg' b'004264.jpg' b'000577.jpg'
b'004320.jpg' b'003726.jpg' b'003859.jpg' b'001369.jpg' b'001056.jpg'
b'003422.jpg' b'003193.jpg' b'001178.jpg' b'000918.jpg' b'000509.jpg'
b'000296.jpg' b'003273.jpg']
====================
[b'001889.jpg' b'003185.jpg' b'000029.jpg' b'002218.jpg' b'001762.jpg'
b'003392.jpg' b'002634.jpg' b'001382.jpg' b'001100.jpg' b'000779.jpg'
b'000544.jpg' b'003537.jpg' b'002630.jpg' b'004138.jpg' b'000539.jpg'
b'002091.jpg' b'000378.jpg' b'002754.jpg' b'002377.jpg' b'002861.jpg'
b'002858.jpg' b'003162.jpg' b'002898.jpg' b'004361.jpg' b'003058.jpg'
b'001686.jpg' b'000629.jpg' b'002349.jpg' b'001722.jpg' b'002675.jpg'
b'002903.jpg' b'001424.jpg']
====================
[b'000194.jpg' b'001345.jpg' b'003553.jpg' b'003031.jpg' b'001821.jpg'
b'003232.jpg' b'000852.jpg' b'003112.jpg' b'000841.jpg' b'002522.jpg'
b'004280.jpg' b'000538.jpg' b'000830.jpg' b'003934.jpg' b'003596.jpg'
b'002004.jpg' b'000794.jpg' b'002015.jpg' b'002620.jpg' b'001974.jpg'
b'003632.jpg' b'002461.jpg' b'003142.jpg' b'000532.jpg' b'001941.jpg'
b'002232.jpg' b'000924.jpg' b'003882.jpg' b'000489.jpg' b'002766.jpg'
b'001795.jpg' b'003866.jpg']
====================
Process finished with exit code 0
總結一下,使用tf.data.Dataset的主要步驟有:
1. 使用tf.data.Dataset.from_tensor_slices從輸入的tensor'中獲取數據,這個tensor可以是圖片的路徑組成的list;
2. 使用dataset.filter選擇是否過濾數據;
3. 使用dataset.map解析數據,如果讀入的只是路徑,那么使用這個函數可以將路徑對應的圖片數據讀進來;
4. 使用dataset.shuffle打亂數據;
5. 使用dataset.batch生成batches;
6. 使用dataset.repeat重復數據集;
7. 使用dataset.prefetch設置在數據請求前預加載的batch數量;
8. 使用iterator =?dataset.make_initializable_iterator(),?batch_data = iterator.get_next()建立迭代器
9. 在session中初始化迭代器sess.run(iterator.initializer),并通過迭代器的get_next()獲取一個batch的數據。
?
2.tfrecord
2.1 使用tfrecord的原因
?? ?正常情況下我們訓練文件夾經常會生成 train, test 或者val文件夾,這些文件夾內部往往會存著成千上萬的圖片或文本等文件,這些文件被散列存著,這樣不僅占用磁盤空間,并且再被一個個讀取的時候會非常慢,繁瑣。占用大量內存空間(有的大型數據不足以一次性加載)。此時我們TFRecord格式的文件存儲形式會很合理的幫我們存儲數據。
?? ?TFRecord內部使用了“Protocol Buffer”二進制數據編碼方案,它只占用一個內存塊,只需要一次性加載一個二進制文件的方式即可,簡單,快速,尤其對大型訓練數據很友好。而且當我們的訓練數據量比較大的時候,可以將數據分成多個TFRecord文件,來提高處理效率。
2.2 tfrecord的寫入
# 聲明一個TFRecordWriter,才能將信息寫入TFRecord文件
# 其中output表示為存儲的路徑,如“output.tfrecord”
writer = tf.python_io.TFRecordWriter(output)
?
# 讀取圖片并進行解碼, input是圖片路徑
image = Image.open("image.jpg")
shape = image.shape
# 將圖片轉換成 string。
image_data = image.tostring()
name = bytes("cat", encoding='utf8')
?
# 創建Example對象,并將Feature一一填充進去
example = tf.train.Example(features=tf.train.Features(feature={
? ?'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
? ?'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
? ?'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
? ? }
? ?))
?? ?? ?
# 將 example 序列化成 string 類型,然后寫入
writer.write(example.SerializeToString())
?
# 全部寫完之后,關閉writer
writer.close()
總結一下,tfrecord主要分為4步:
聲明TFRecordWriter;
創建Example對象,并將Feature存入;
將Example序列化成string類型寫入;
關閉writer
2.3 tfrecord的讀取
# 定義一個reader
reader = tf.TFRecordReader()
# 讀取tfrecord文件,得到一個filename_queue【中括號必須保存下來】
filename_queue=tf.train.string_input_producer(['titanic_train.tfrecords'])
# 返回文件名和文件
_,serialized_example=reader.read(filename_queue)
# 上面的serialized_example是無法直接查看的,需要去按照特征進行解析
features = tf.parse_single_example(serialized_example,features={
? ? 'imgae': tf.FixedLenFeature([], tf.string)
? ? 'label': tf.FixedLenFeature([], tf.string)
})
?
# 每次將數據包裝成一個batch,capacity為隊列能夠容納的最大元素個數
image, label = tf.train.shuffle_batch([features['image'], features['label']], batch_size=16, capacity=500)
?
with tf.Session() as sess:
?? ?tf.global_variables_initializer().run()
?? ?# 創建 Coordinator, 負責實現數據輸入線程的同步
?? ?coord = tf.train.Coordinator()
?? ?# 啟動隊列
?? ?threads=tf.train.start_queue_runners(sess=sess, coord)
? ? # 喂數據實現訓練
? ? img, lab = sess.run([image, label])
?
? ? # 線程同步
?? ?coord.request_stop()
????coord.join(threads=threads)
?
總結一下,tfrecord讀取數據的步驟主要有:
使用tf.TFRecordReader()定義reader
使用tf.train.string_input_producer讀取tfrecord文件
使用reader.read讀取第二部返回的文件名隊列
使用tf.parse_single_example解析文件名隊列
使用tf.train.shuffle_batch將數據打亂并包裝成一個batch
在session通過tf.train.Coordinator()實現線程同步
threads=tf.train.start_queue_runners(sess=sess, coord)啟動隊列
通過coord.request_stop()和coord.join(threads=threads)實現線程同步
3.兩種方式的區別
? ? ? ?tfrecord需要提前將數據存成tfrecord文件,這樣可以減少每次打開文件的時間消耗,針對于大規模數據集訓練模型上有幫助。但是問題在于這種方式比較死板,如果有新的數據集,就需要繼續生成tfrecord文件。
?
? ? ?而tf.data.Dataset的方式就比較靈活了,采用pipeline的方式,在GPU訓練數據時,CPU準備數據,不需要提前生成其他文件。
?
? ? ?總之,這兩種方式都可以處理大規模數據集的訓練,但個人覺得tf.data.Dataset要好用一些。
?
參考資料:
TensorFlow之tfrecords文件詳細教程(https://blog.csdn.net/qq_27825451/article/details/83301811)
【Tensorflow】你可能無法回避的 TFRecord 文件格式詳細講解(https://blog.csdn.net/briblue/article/details/80789608)
TensorFlow數據讀取機制:文件隊列 tf.train.slice_input_producer和 tf.data.Dataset機制(https://blog.csdn.net/guyuealian/article/details/85106012)
?
?
?
?
?
?
?
?
?
?
總結
以上是生活随笔為你收集整理的tf.data.Dataset与tfrecord学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 云的世界
- 下一篇: 汉语拼音标注,汉字加拼音