在pytorch中自定义dataset读取数据2021-1-8学习笔记
生活随笔
收集整理的這篇文章主要介紹了
在pytorch中自定义dataset读取数据2021-1-8学习笔记
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
在pytorch中自定義dataset讀取數據
utils
import os import json import pickle import randomimport matplotlib.pyplot as pltdef read_split_data(root: str, val_rate: float = 0.2):# val_rate劃分驗證集的比例random.seed(0) # 保證隨機結果可復現 #隨機種子設置為0,大家劃分的是一樣的assert os.path.exists(root), "dataset root: {} does not exist.".format(root) #不存在路徑報錯# 遍歷文件夾,一個文件夾對應一個類別flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]#不是文件夾丟棄# 排序,保證順序一致flower_class.sort()# 生成類別名稱以及對應的數字索引class_indices = dict((k, v) for v, k in enumerate(flower_class))json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)train_images_path = [] # 存儲訓練集的所有圖片路徑train_images_label = [] # 存儲訓練集圖片對應索引信息val_images_path = [] # 存儲驗證集的所有圖片路徑val_images_label = [] # 存儲驗證集圖片對應索引信息every_class_num = [] # 存儲每個類別的樣本總數supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后綴類型# 遍歷每個文件夾下的文件for cla in flower_class:cla_path = os.path.join(root, cla) #獲得該類別的路徑# 遍歷獲取supported支持的所有文件路徑images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)if os.path.splitext(i)[-1] in supported]#splitext(i)[-1]分割出文件名稱和后綴名 然后用in判斷是否在supported里# 獲取該類別對應的索引image_class = class_indices[cla]# 記錄該類別的樣本數量every_class_num.append(len(images))# 按比例隨機采樣驗證樣本val_path = random.sample(images, k=int(len(images) * val_rate))for img_path in images:if img_path in val_path: # 如果該路徑在采樣的驗證集樣本中則存入驗證集val_images_path.append(img_path)val_images_label.append(image_class)else: # 否則存入訓練集train_images_path.append(img_path)train_images_label.append(image_class)print("{} images were found in the dataset.".format(sum(every_class_num)))plot_image = Falseif plot_image:# 繪制每種類別個數柱狀圖plt.bar(range(len(flower_class)), every_class_num, align='center')# 將橫坐標0,1,2,3,4替換為相應的類別名稱plt.xticks(range(len(flower_class)), flower_class)# 在柱狀圖上添加數值標簽for i, v in enumerate(every_class_num):plt.text(x=i, y=v + 5, s=str(v), ha='center')# 設置x坐標plt.xlabel('image class')# 設置y坐標plt.ylabel('number of images')# 設置柱狀圖的標題plt.title('flower class distribution')plt.show()return train_images_path, train_images_label, val_images_path, val_images_labeldef plot_data_loader_image(data_loader):batch_size = data_loader.batch_sizeplot_num = min(batch_size, 4)json_path = './class_indices.json'assert os.path.exists(json_path), json_path + " does not exist."json_file = open(json_path, 'r')class_indices = json.load(json_file)for data in data_loader:images, labels = datafor i in range(plot_num):# [C, H, W] -> [H, W, C] transpose調整順序img = images[i].numpy().transpose(1, 2, 0)# 反Normalize操作img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255label = labels[i].item()plt.subplot(1, plot_num, i+1)plt.xlabel(class_indices[str(label)])plt.xticks([]) # 去掉x軸的刻度plt.yticks([]) # 去掉y軸的刻度plt.imshow(img.astype('uint8'))plt.show()def write_pickle(list_info: list, file_name: str):with open(file_name, 'wb') as f:pickle.dump(list_info, f)def read_pickle(file_name: str) -> list:with open(file_name, 'rb') as f:info_list = pickle.load(f)return info_listmydataset
from PIL import Image import torch from torch.utils.data import Datasetclass MyDataSet(Dataset):"""自定義數據集"""def __init__(self, images_path: list, images_class: list, transform=None):#初始化函數self.images_path = images_pathself.images_class = images_classself.transform = transformdef __len__(self):#計算該數據集下所有的樣本個數return len(self.images_path)def __getitem__(self, item):#每次傳入一個索引,就返回該索引對應的圖片以及標簽信息img = Image.open(self.images_path[item])#獲得img的路徑,然后得到PIL格式圖像,pytorch用PIL比openCV好# RGB為彩色圖片,L為灰度圖片if img.mode != 'RGB':raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))#報錯,如果是灰度,就把上一行改成Llabel = self.images_class[item]if self.transform is not None:img = self.transform(img)#對圖像進行預處理return img, label@staticmethod#是個靜態方法def collate_fn(batch):#dataloader會使用# 官方實現的default_collate可以參考# https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.pyimages, labels = tuple(zip(*batch))#zip將圖片和圖片放一起,標簽和標簽放一起images = torch.stack(images, dim=0)#拼接,并會在dim=0的維度上進行拼接(就是拼成一個矩陣)labels = torch.as_tensor(labels)#標簽也轉換成tensorreturn images, labelsmain
import osimport torch from torchvision import transformsfrom my_dataset import MyDataSet from utils import read_split_data, plot_data_loader_image# http://download.tensorflow.org/example_images/flower_photos.tgz root = "/home/wz/my_github/data_set/flower_data/flower_photos" # 數據集所在根目錄def main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),#隨機裁剪transforms.RandomHorizontalFlip(),#水平翻轉transforms.ToTensor(),#轉化成tensor格式transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}##這個很重要,可以自己實現#實例化datasettrain_data_set = MyDataSet(images_path=train_images_path,#訓練集圖像列表images_class=train_images_label,#訓練集所有圖像對應的標簽信息transform=data_transform["train"])#預處理方法batch_size = 8nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workersprint('Using {} dataloader workers'.format(nw))train_loader = torch.utils.data.DataLoader(train_data_set,#從實例化的dataset當中取得圖片,然后打包成一個一個batch,然后輸入網絡進行訓練batch_size=batch_size,shuffle=True,#打亂數據集num_workers=nw,#訓練時建議nw,調試時建議0collate_fn=train_data_set.collate_fn)# plot_data_loader_image(train_loader)for step, data in enumerate(train_loader):images, labels = dataif __name__ == '__main__':main()總結
以上是生活随笔為你收集整理的在pytorch中自定义dataset读取数据2021-1-8学习笔记的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 动态规划和回溯的比较
- 下一篇: 蓝桥备赛第一周2021.1.11 递归