pytorch dataset_【小白学PyTorch】16.TF2读取图片的方法
?擴展之tensorflow2.0 | 15 TF2實現一個簡單的服裝分類任務
小白學PyTorch | 14 tensorboardX可視化教程
小白學PyTorch | 13 EfficientNet詳解及PyTorch實現
小白學PyTorch | 12 SENet詳解及PyTorch實現
小白學PyTorch | 11 MobileNet詳解及PyTorch實現
小白學PyTorch | 10 pytorch常見運算詳解
小白學PyTorch | 9 tensor數據結構與存儲結構
小白學PyTorch | 8 實戰之MNIST小試牛刀
小白學PyTorch | 7 最新版本torchvision.transforms常用API翻譯與講解
小白學PyTorch | 6 模型的構建訪問遍歷存儲(附代碼)
小白學PyTorch | 5 torchvision預訓練模型與數據集全覽
小白學PyTorch | 4 構建模型三要素與權重初始化
小白學PyTorch | 3 淺談Dataset和Dataloader
小白學PyTorch | 2 淺談訓練集驗證集和測試集
小白學PyTorch | 1 搭建一個超簡單的網絡
小白學PyTorch | 動態圖與靜態圖的淺顯理解
參考目錄:
1 PIL讀取圖片
2 TF讀取圖片
3 TF構建數據集
本文的代碼已經上傳,在作者公眾號后臺回復【PyTorch】獲取。
1 PIL讀取圖片
想要把一個圖片,轉換成RGB3通道的一個張量,我們怎么做呢?大家第一反應應該是PIL這個庫吧
from?PIL?import?Imageimport?numpy?as?np
image?=?Image.open('./bug1.jpg')
image.show()
展示的圖片:
然后我們這個image現在是PIL格式的,我們使用numpy.array()來將其轉換成numpy的張量的形式:
image?=?np.array(image)print(image.shape)
>>>(326,?312,?3)
可以看到,這個第三維度是3。對于pytorch而言,數據的第一維度應該是樣本數量,第二維度是通道數,第三四是圖像的寬高,因此PIL讀入的圖片,往往需要把通道數的這個維度移動到第二維度上才能對接上pytorch的形式。(transpose方法來實現這個功能,這里不細說)
2 TF讀取圖片
下面是重點啦,對于tensorflow,tf中自己帶了一個解碼函數,先看一下我的文件目錄:
import?tensorflow?as?tfimages?=?tf.io.gfile.glob('./*.jpeg')
print(images,type(images))
>?['.\\bug1.jpeg',?'.\\bug2.jpeg']?<class?'list'>
可以看出來:
- 這個tensorflow.io.gfile.glob()是讀取路徑下的所有符合條件的文件,并且把路徑做成一個list返回;
- 這個功能也可以用glob庫函數實現,我記得是glob.glob()方法;
- 這里的bug1和bug2其實是同一張圖片,都是上面的那個小兔子。
image?=?tf.image.decode_jpeg(image,channels=3)
print(image.shape,type(image))
>?(326,?312,?3)?<class?'tensorflow.python.framework.ops.EagerTensor'>
需要注意的是:
- tf.io.read_file()這個得到的返回值是二進制格式,所以需要下面的tf.image.decode_jpeg進行一個解碼;
- decode_jpeg的第一個參數就是讀取的二進制文件,然后channels是輸出的圖片的通道數,3就是RPB三個通道,如果是1的話,就是灰度圖片,ratio是圖片大小的一個縮小比例,默認是1,可以是2和4,一會看一下ratio=2的情況;
- 這個image的type是一個tensorflow特別的Tensor的形式,而不是pytorch的那種tensor的形式了。
image?=?tf.image.decode_jpeg(image,channels=1,ratio=2)
print(image.shape,type(image))
>?(163,?156,?1)?<class?'tensorflow.python.framework.ops.EagerTensor'>
寬高都變成了原來的一半,然后通道數是1,都和預想的一樣。使用decode_jpeg等解碼函數得到的結果,是uint8的類型的,簡單地說就是整數,0到255范圍的。在對圖片進行操作的時候,我們需要將其標準化到0到1區間的,因此需要將其轉換成float32類型的。所以對上述代碼進行補充:
image?=?tf.io.read_file('./bug1.jpeg')image?=?tf.image.decode_jpeg(image,channels=1,ratio=2)
print(image.shape,type(image))
image?=?tf.image.resize(image,[256,256])?#?統一圖片大小
image?=?tf.cast(image,tf.float32)?#?轉換類型
image?=?image/255?#?歸一化
print(image)
從結果來看,數據類型已經改變:
3 TF構建數據集
下面是dataset更正式的寫法,關于TF2的問題,不要百度!百度到的都是TF1的解答,看的我暈死了,TF的API的結構真是不太友好。。。
def?read_image(path):????image?=?tf.io.read_file(path)
????image?=?tf.image.decode_jpeg(image,?channels=3,?ratio=1)
????image?=?tf.image.resize(image,?[256,?256])??#?統一圖片大小
????image?=?tf.cast(image,?tf.float32)??#?轉換類型
????image?=?image?/?255??#?歸一化
????return?image
images?=?tf.io.gfile.glob('./*.jpeg')
dataset?=?tf.data.Dataset.from_tensor_slices(images)
AUTOTUNE?=?tf.data.experimental.AUTOTUNE
dataset?=?dataset.map(read_image,num_parallel_calls=AUTOTUNE)
dataset?=?dataset.shuffle(1).batch(1)
for?a?in?dataset.take(2):
????print(a.shape)
代碼中需要注意的是:
- glob獲取一個文件的list,本次就兩個文件名字,一個bug1.jpeg,一個bug2.jpeg;
- tf.data.Dataset.from_tensor_slices()返回的就是一個tensorflow的dataset類型,可以簡單理解為一個可迭代的list,并且有很多其他方法;
- dataset.map就是用實現定義好的函數,對處理dataset中每一個元素,在上面代碼中是把路徑的字符串變成該路徑讀取的圖片張量,對圖片的預處理應該也在這部分進行吧;
- dataset.shuffle就是亂序,.batch()就是把dataset中的元素組裝batch;
- 在獲取dataset中的元素的時候,TF1中有什么迭代器的定義啊,什么iter,但是TF2不用這些,直接.take(num)就行了,這個num就是從dataset中取出來的batch的數量,也就是循環的次數吧。
- AUTOTUNE = tf.data.experimental.AUTOTUNE 就是根據你的cpu的情況,自動判斷多線程的數量。上面代碼的輸出結果為:
往期精彩回顧
適合初學者入門人工智能的路線及資料下載
機器學習及深度學習筆記等資料打印
機器學習在線手冊
深度學習筆記專輯
《統計學習方法》的代碼復現專輯
AI基礎下載
機器學習的數學基礎專輯
獲取一折本站知識星球優惠券,復制鏈接直接打開:
https://t.zsxq.com/662nyZF
本站qq群704220115。
加入微信群請掃碼進群(如果是博士或者準備讀博士請說明):
總結
以上是生活随笔為你收集整理的pytorch dataset_【小白学PyTorch】16.TF2读取图片的方法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 安卓md5查看器(安卓md5)
- 下一篇: linux的路径怎么写(linux的路径