PyTorch CIFAR-100 离线加载二进制文件
生活随笔
收集整理的這篇文章主要介紹了
Pytorch cifar100离线加载二进制文件
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
說明:直接將 CIFAR-100 二進制文件加載到 PyTorch。
?
'''
Directly load the CIFAR-100 binary (pickled) files into PyTorch.

The dataset directory is presumably the extracted "python version" archive
containing the ``meta``, ``test`` and ``train`` files -- confirm locally.
'''
import os
import cv2
import pickle
import time
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch.autograd import Variable
import torch.utils.data as Data
from torchvision import transforms


def load_CIFAR_100(root, train=True, fine_label=True):
    """Load one CIFAR-100 split from its pickled binary file.

    Args:
        root: directory that contains the ``train`` and ``test`` files.
        train: True to load the training split, False for the test split.
        fine_label: True to return the 100-class fine labels, False to
            return the 20-class coarse (super-class) labels.

    Returns:
        (X, Y) where X is the image array reshaped to (N, 32, 32, 3)
        (channels-last) and Y is a 1-D ndarray of labels.
    """
    # os.path.join works whether or not ``root`` ends with a separator.
    filename = os.path.join(root, 'train' if train else 'test')
    # NOTE: pickle.load can execute arbitrary code; only load trusted,
    # officially distributed dataset files here.
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='bytes')

    X = datadict[b'data']
    # Each row is stored channel-major (3 x 32 x 32); move channels last.
    # -1 infers N (50000 for train, 10000 for test) instead of hard-coding it.
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    # fine_labels: fine-grained classes, 100 in total.
    # coarse_labels: super-classes, 20 in total; each super-class groups 5
    # fine classes (e.g. "trees" covers maple, oak, palm, pine and willow).
    if fine_label:
        Y = datadict[b'fine_labels']
    else:
        Y = datadict[b'coarse_labels']
    Y = np.array(Y)
    return X, Y


class DealDataset(Data.Dataset):
    """Map-style Dataset over one CIFAR-100 split held in memory."""

    def __init__(self, root, train=True, fine_label=True, transform=None):
        # self.x holds the (N, 32, 32, 3) images, self.y the labels.
        self.x, self.y = load_CIFAR_100(root, train=train, fine_label=fine_label)
        self.transform = transform
        self.train = train

    def __getitem__(self, index):
        """Return (image, label) for ``index``, applying ``transform`` if set."""
        img, target = self.x[index], int(self.y[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.x)


root = r'E:\cifar-100-python' + '/'
batch_size = 20

# Instantiate the Dataset objects; they are then handed to DataLoader.
trainDataset = DealDataset(root, train=True, fine_label=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, fine_label=True, transform=transforms.ToTensor())

# Loaders for the training and test data.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=batch_size,  # each batch contains batch_size images
    shuffle=False,
)
test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=batch_size,
    shuffle=False,
)

if __name__ == '__main__':
    # trainDataset exposes the ndarray attributes x (images) and y (labels).
    print(f'trainDataset.y.shape:{trainDataset.y.shape}\n')
    print(f'trainDataset.x.shape:{trainDataset.x.shape}\n')
    # train_loader exposes batch_size (int) and dataset (the DealDataset),
    # whose x / y attributes are the same ndarrays.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.y.shape: {train_loader.dataset.y.shape}\n')
    print(f'train_loader.dataset.x.shape: {train_loader.dataset.x.shape}\n')

    # --- Visualisation 1: OpenCV ---------------------------------------
    images, labels = next(iter(train_loader))
    img = torchvision.utils.make_grid(images, nrow=10)
    img = img.numpy().transpose(1, 2, 0)
    # OpenCV expects BGR while img is RGB, so reverse the channel axis.
    cv2.imshow('img', img[:, :, ::-1])
    cv2.waitKey(0)
運行結果:
顯示圖
總結
以上是生活随笔為你收集整理的Pytorch cifar100离线加载二进制文件的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Pytorch cifar10离线加
- 下一篇: Pytorch MNIST直接离线加载二