tensorflow 图片批处理--- tf.train.batch
生活随笔
收集整理的這篇文章主要介紹了
tensorflow 图片批处理--- tf.train.batch
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
當我們使用tensorflow進行深度學習時,進行訓練模型時,我們往往要讀取大量的圖片進行批處理輸入模型進行訓練.
如果我們一次性讀取全部圖片或者過多張圖片,內存將有可能溢出.
如果我們一次讀取小批量圖片,再將圖片轉換成tensor,然后再輸入模型,則隨著模型的迭代次數增大,內存占用將越來越大,最終內存溢出.如下代碼:
sess=tf.Session()
ImgFiles= ***** (包括所有訓練集圖片的文件名)
for imgFile in imgFiles:
img=scipy.misc.imread(imgFile) #讀取圖片
img=tf.convert_to_tensor(img,dtype='float32') #將圖片轉化成tensor
img=preprocessing(img) #圖片預處理
res=net(img) #將圖片輸入網絡模型進行訓練,得出結果
如上代碼,因為tensor結點是不會自動回收的,即使你變量名被覆蓋,原來的tensor結點依然占用內存,最終內存占用將越來越大,所以不要在循環里面生成tensor.
可通過如下方法檢測是否不斷生成計算節點
在sess里面,循環外面,使用graph.finalize()鎖定graph.如果運行時保存,則說明有計算節點加入.
所以,我們要使用tf.train.batch進行圖片讀取訓練訓練.
代碼如下:
def read(Path):
filenames = [join(Path, f) for f in listdir(Path) if isfile(join(Path, f))] #Path為圖片訓練集的文件夾路徑,返回的是所有訓練集圖片的路徑的集合
filename_queue = tf.train.string_input_producer(filenames, shuffle=True, num_epochs=10) #將圖片產生一個隊列,可控制是否排序,圖片的迭代次數
reader = tf.WholeFileReader() #產生一個讀取器reader
_, img_bytes = reader.read(filename_queue) #將隊列輸入讀取器reader當中,讀取序列
image = tf.image.decode_png(img_bytes, channels=3) #對序列解碼,現在image還是一張圖片,為tensor
image=preprocessing(image) #對圖片進行預處理
image=tf.train.batch([image], 2, dynamic_pad=True) #將圖片合并生成一個批次,第二個參數2是控制這個批次包含多少張圖片.
with tf.Graph().as_default() as g: with tf.Session() as sess: #協調器要求在with tf.Session() as sess 里面使用. img=read(Path) coord = tf.train.Coordinator() #創建一個協調器,管理線程 threads = tf.train.start_queue_runners(coord=coord) try: while not coord.should_stop(): #進行模型訓練 finally: coord.request_stop() coord.join(threads)
with tf.Graph().as_default() as g: with tf.Session() as sess: #協調器要求在with tf.Session() as sess 里面使用. img=read(Path) coord = tf.train.Coordinator() #創建一個協調器,管理線程 threads = tf.train.start_queue_runners(coord=coord) try: while not coord.should_stop(): #進行模型訓練 finally: coord.request_stop() coord.join(threads)
總結
以上是生活随笔為你收集整理的tensorflow 图片批处理--- tf.train.batch的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 7n65场效应管参数-7N65参数代换K
- 下一篇: 第五届信息科学、电气与自动化工程国际学术