pytorch数据预处理
生活随笔
收集整理的這篇文章主要介紹了
pytorch数据预处理
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
一,數據加載
數據路徑:
打印結果:
二,數據歸一化?
PyTorch提供了torchvision1。它是一個視覺工具包,提供了很多視覺圖像處理的工具,其中transforms模塊提供了對PIL?Image對象和Tensor對象的常用操作。
對PIL Image的操作包括:
?
- Scale:調整圖片尺寸,長寬比保持不變
- CenterCrop、RandomCrop、RandomResizedCrop: 裁剪圖片
- Pad:填充
- ToTensor:將PIL Image對象轉成Tensor,會自動將[0, 255]歸一化至[0, 1]
- transforms.ColorJitter(0.3, 0.3, 0.2) 顏色抖動
- transforms.RandomRotation(10)隨機旋轉
對Tensor的操作包括:
?
- Normalize:標準化,即減均值,除以標準差
- ToPILImage:將Tensor轉為PIL Image對象
三,利用fer2013數據集進行預處理
數據集地址:https://download.csdn.net/download/fanzonghao/11183885
''' Fer2013 Dataset class''' from __future__ import print_function from PIL import Image import numpy as np import h5py import torch.utils.data as data import cv2 import torchvision.transforms as transforms# 定義對數據的預處理 transform = transforms.Compose([transforms.ToTensor(), # 轉為Tensor 歸一化至0~1transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 歸一化]) class FER2013(data.Dataset):"""`FER2013 Dataset.Args:train (bool, optional): If True, creates dataset from training set, otherwisecreates from test set.transform (callable, optional): A function/transform that takes in an PIL imageand returns a transformed version. E.g, ``transforms.RandomCrop``"""def __init__(self, path,split='Training', transform=None):self.transform = transformself.split = split # training set or test setself.data = h5py.File(path, 'r', driver='core')# now load the picked numpy arraysif self.split == 'Training':self.train_data = self.data['Training_pixel']self.train_labels = self.data['Training_label']self.train_data = np.asarray(self.train_data)self.train_data = self.train_data.reshape((28709, 48, 48))elif self.split == 'PublicTest':self.PublicTest_data = self.data['PublicTest_pixel']self.PublicTest_labels = self.data['PublicTest_label']self.PublicTest_data = np.asarray(self.PublicTest_data)self.PublicTest_data = self.PublicTest_data.reshape((3589, 48, 48))else:self.PrivateTest_data = self.data['PrivateTest_pixel']self.PrivateTest_labels = self.data['PrivateTest_label']self.PrivateTest_data = np.asarray(self.PrivateTest_data)self.PrivateTest_data = self.PrivateTest_data.reshape((3589, 48, 48))def __getitem__(self, index):"""Args:index (int): IndexReturns:tuple: (image, target) where target is index of the target class."""if self.split == 'Training':img, target = self.train_data[index], self.train_labels[index]elif self.split == 'PublicTest':img, target = self.PublicTest_data[index], self.PublicTest_labels[index]else:img, target = self.PrivateTest_data[index], self.PrivateTest_labels[index]# doing this so that it is consistent with all other datasets# to return a PIL Imageimg = img[:, :, np.newaxis]img = np.concatenate((img, img, img), axis=2)img = Image.fromarray(img)if self.transform is not None:img = self.transform(img)return img, targetdef __len__(self):if self.split == 'Training':return len(self.train_data)elif self.split == 'PublicTest':return len(self.PublicTest_data)else:return len(self.PrivateTest_data)if __name__ == '__main__':train_data=FER2013(path='./data/data.h5',split='Training',transform=transform)train_loader = data.DataLoader(dataset=train_data,batch_size=8,shuffle=True,num_workers=2)print(len(train_data))# for i,(img,label) in enumerate(train_data):# if i<1:# img=np.transpose(np.array(img),(1,2,0))# print(img.shape)# img=(img*0.5+0.5)*255# cv2.imwrite('1.jpg',img)# print(label.shape)for i,(img, label) in enumerate(train_loader):if i<1:print('train')img=np.transpose(np.array(img)[0],(1,2,0))img = (img * 0.5 + 0.5) * 255cv2.imwrite('2.jpg',img)結果:
?
總結
以上是生活随笔為你收集整理的pytorch数据预处理的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: OpenCV——绘制基本图形
- 下一篇: JS入门程序(一)