PyTorch 之 Datasets
生活随笔
收集整理的這篇文章主要介紹了
PyTorch 之 Datasets
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
實現一個定制的 Dataset 類
Dataset 類是 PyTorch 數據集中最為重要的一個類,也是 PyTorch 中所有自定義數據集類都應該繼承的父類。繼承時,必須在子類中重寫(override)父類的兩個特殊方法(魔術方法):
- `__getitem__(self, index)` # 支持數據集索引的函數,根據 index 返回一個樣本
- `__len__(self)` # 返回數據集的大小(樣本總數)
自定義 Dataset 的基本框架:
class CustomDataset(data.Dataset):  # a custom dataset must subclass data.Dataset
    """Skeleton of a user-defined, index-based dataset.

    Fill in the three methods below:

    * ``__init__`` records where the samples live (a path or a list of
      file names);
    * ``__getitem__`` loads and preprocesses the sample at ``index``;
    * ``__len__`` reports how many samples the dataset holds.
    """

    def __init__(self):
        # TODO: initialize the file path or the list of file names.
        pass

    def __getitem__(self, index):
        # TODO:
        #   1. Read the sample at ``index`` from its file
        #      (e.g. numpy.fromfile, PIL.Image.open).
        #   2. Preprocess what was read (e.g. torchvision transforms).
        #   3. Return a data pair (e.g. an image and its label).
        return None

    def __len__(self):
        # TODO: replace 0 with the total number of samples in your dataset.
        return 0


# Example:
class MyDataset(Dataset):
    """Dataset of the image files stored directly inside ``root``.

    Args:
        root: Root directory that contains the image files. Only its
            direct entries are scanned (``os.scandir`` does not recurse).
        augment: Optional callable applied to each loaded image for data
            augmentation; pass ``None`` (or any falsy value) to disable it.

    Note:
        ``open_image`` (reads an image file, e.g. via PIL or OpenCV) and
        ``to_tensor`` (converts an image to a tensor) are helpers that must
        be defined elsewhere.
    """

    # File-name suffixes treated as images. They are compared against the
    # lower-cased name, so ".JPG", ".Png", etc. are accepted as well.
    IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, root, augment=None):
        # Sorted so that a given index always refers to the same file:
        # os.scandir yields entries in an arbitrary, platform-dependent order.
        image_paths = sorted(
            entry.path
            for entry in os.scandir(root)
            if entry.is_file()
            and entry.name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        # Paths of all images in the dataset.
        self.image_files = np.array(image_paths)
        self.augment = augment

    def __len__(self):
        """Return the number of images (required by the Dataset contract)."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load image ``index``, optionally augment it, and return a tensor."""
        image = open_image(self.image_files[index])
        if self.augment:
            image = self.augment(image)  # data augmentation
        # Images handed to PyTorch must be tensors.
        return to_tensor(image)


# Below is the official MNIST example:
class MNIST(data.Dataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        RuntimeError: If the processed files are not present under ``root``
            after the optional download step.
    """
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
    class_to_idx = {_class: i for i, _class in enumerate(classes)}

    @property
    def targets(self):
        """Labels of the split selected by ``self.train``."""
        if self.train:
            return self.train_labels
        return self.test_labels

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Each processed file holds the (data, labels) pair written by
        # download(); only the requested split is loaded.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.train:
            return len(self.train_data)
        return len(self.test_data)

    def _check_exists(self):
        """Return True if both processed split files exist under ``root``."""
        processed = os.path.join(self.root, self.processed_folder)
        return (os.path.exists(os.path.join(processed, self.training_file)) and
                os.path.exists(os.path.join(processed, self.test_file)))

    def _makedir(self, folder):
        """Create ``root/folder`` (and parents), tolerating an existing one."""
        import errno

        try:
            os.makedirs(os.path.join(self.root, folder))
        except OSError as e:
            # An already-existing directory is fine; anything else is a real error.
            if e.errno != errno.EEXIST:
                raise

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        # Check first so that no download machinery is needed when the
        # processed files are already there.
        if self._check_exists():
            return

        import gzip
        import shutil
        from contextlib import closing
        try:
            from urllib.request import urlopen  # Python 3
        except ImportError:  # Python 2
            from urllib2 import urlopen

        # Each folder gets its own EEXIST handling, so an existing raw folder
        # cannot prevent the processed folder from being created.
        self._makedir(self.raw_folder)
        self._makedir(self.processed_folder)

        # download files
        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            # closing() releases the HTTP response even if writing fails.
            with closing(urlopen(url)) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            # splitext strips only the trailing ".gz"; str.replace would also
            # rewrite a ".gz" occurring elsewhere in the path (e.g. in root).
            with open(os.path.splitext(file_path)[0], 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                shutil.copyfileobj(zip_f, out_f)
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        raw = os.path.join(self.root, self.raw_folder)
        training_set = (
            read_image_file(os.path.join(raw, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(raw, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(raw, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(raw, 't10k-labels-idx1-ubyte'))
        )
        processed = os.path.join(self.root, self.processed_folder)
        with open(os.path.join(processed, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(processed, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
        # Same truthiness test as the data-loading code, so the reported
        # split always matches the split that was actually loaded.
        tmp = 'train' if self.train else 'test'
        fmt_str += ' Split: {}\n'.format(tmp)
        fmt_str += ' Root Location: {}\n'.format(self.root)
        tmp = ' Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = ' Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str


# Reposted from: https://www.cnblogs.com/xxxxxxxxx/p/11429051.html
總結
以上是生活随笔為你收集整理的PyTorch 之 Datasets的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch 之 DataLoader
- 下一篇: 详解python正则\b和\B的区别