Pytorch cifar10离线加载二进制文件
生活随笔
收集整理的這篇文章主要介紹了《Pytorch cifar10离线加载二进制文件》,小編覺得挺不錯的,現在分享給大家,幫大家做個參考。
說明:直接離線加載 cifar10 到 Pytorch。
"""Load the CIFAR-10 python batches from local disk straight into PyTorch.

The following six files are expected inside the dataset directory:

    data_batch_1
    data_batch_2
    data_batch_3
    data_batch_4
    data_batch_5
    test_batch
"""
import os
import pickle

import cv2  # only needed by the optional OpenCV visualisation sketched below
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data as Data
import torchvision
from torch.autograd import Variable
from torchvision import transforms


def load_CIFAR_batch(filename):
    """Load a single CIFAR batch file.

    Args:
        filename: Path to one pickled batch file.

    Returns:
        A tuple ``(X, Y)``. ``X`` is the ``'data'`` entry reshaped to
        ``(N, 3, 32, 32)`` and transposed to ``(N, 32, 32, 3)``; ``Y`` is the
        ``'labels'`` entry converted to a NumPy array.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point this at batch files obtained from a trusted source.
    with open(filename, 'rb') as f:
        # presumably 'latin1' is needed because the batches were pickled by
        # Python 2 -- TODO confirm against the dataset distribution notes
        datadict = pickle.load(f, encoding='latin1')

    # -1 lets NumPy infer the number of samples instead of assuming 10000.
    # The dtype is left untouched (the original kept a float cast disabled).
    X = datadict['data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    Y = np.array(datadict['labels'])
    return X, Y


def _load_train_batches(ROOT):
    """Load ``data_batch_1`` .. ``data_batch_5`` from ROOT and concatenate them."""
    xs = []
    ys = []
    for b in range(1, 6):
        X, Y = load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % b))
        xs.append(X)
        ys.append(Y)
    # Concatenate along the first (sample) axis.
    return np.concatenate(xs), np.concatenate(ys)


def load_CIFAR10(ROOT):
    """Load the complete CIFAR-10 dataset.

    Args:
        ROOT: Directory that contains the five training batches and
            ``test_batch``.

    Returns:
        ``(Xtrain, Ytrain, Xtest, Ytest)`` as NumPy arrays.
    """
    Xtrain, Ytrain = _load_train_batches(ROOT)
    Xtest, Ytest = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtrain, Ytrain, Xtest, Ytest


class DealDataset(Data.Dataset):
    """Dataset wrapper around the locally stored CIFAR-10 batches.

    Depending on ``train`` the instance exposes either ``train_set`` /
    ``train_labels`` or ``test_set`` / ``test_labels`` (NumPy arrays).
    """

    def __init__(self, root, train=True, transform=None):
        """Read the requested split from ``root``.

        Args:
            root: Directory holding the CIFAR-10 batch files.
            train: Load the training batches when true, else ``test_batch``.
            transform: Optional callable applied to every image on access.
        """
        # Only the requested split is read; previously every instance
        # unpickled all six files and discarded half of the result.
        if train:
            self.train_set, self.train_labels = _load_train_batches(root)
        else:
            self.test_set, self.test_labels = load_CIFAR_batch(
                os.path.join(root, 'test_batch'))
        self.transform = transform
        self.train = train

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``; the label is a plain int."""
        if self.train:
            img, target = self.train_set[index], int(self.train_labels[index])
        else:
            img, target = self.test_set[index], int(self.test_labels[index])

        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_set)
        return len(self.test_set)


root = r'E:\cifar-10-python\cifar-10-batches-py'
batch_size = 8

# Instantiate the Dataset objects; they are then handed to DataLoader.
trainDataset = DealDataset(root, train=True, transform=transforms.ToTensor())
testDataset = DealDataset(root, train=False, transform=transforms.ToTensor())

# Loaders for the training and the test data.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=batch_size,  # every batch carries batch_size images
    shuffle=False,
)

test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=batch_size,
    shuffle=False,
)

if __name__ == '__main__':
    # trainDataset exposes train_labels and train_set, both ndarrays.
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')

    # train_loader exposes batch_size (int) and dataset (DealDataset); the
    # dataset in turn carries the same train_labels / train_set arrays.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')

    # Visualisation 1 (optional), using OpenCV:
    #   images, labels = next(iter(train_loader))
    #   img = torchvision.utils.make_grid(images, nrow=10)
    #   img = img.numpy().transpose(1, 2, 0)
    #   # OpenCV expects BGR while img is RGB, hence the channel flip.
    #   cv2.imshow('img', img[:, :, ::-1])
    #   cv2.waitKey(0)

    # Visualisation 2, using matplotlib.
    dataiter = iter(train_loader)
    # The built-in next() is required: DataLoader iterators no longer offer
    # a .next() method.
    images, labels = next(dataiter)
    images = images.numpy()

    classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

    fig = plt.figure(figsize=(4, 4))
    n_images = len(images)
    # add_subplot needs integer grid sizes; round up so an odd batch still
    # fits into the two rows.
    n_cols = max(1, (n_images + 1) // 2)
    for idx in range(n_images):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        # Channel-first batch entry -> channel-last layout for imshow.
        ax.imshow(images[idx].transpose(1, 2, 0))
        ax.set_title(classes[int(labels[idx])])
    plt.show()
運行結果
顯示圖
總結
以上是生活随笔為你收集整理的Pytorch cifar10离线加载二进制文件的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: HMM和CRF 条件随机场详解
- 下一篇: Pytorch cifar100离线加载