Pytorch Fashion_MNIST直接离线加载二进制文件到pytorch
生活随笔
收集整理的这篇文章主要介绍了
Pytorch Fashion_MNIST直接离线加载二进制文件到pytorch，
小编觉得挺不错的，现在分享给大家，帮大家做个参考。
说明：将 Fashion_MNIST 的二进制 .gz 文件直接离线加载到 PyTorch。
"""Load the four gzipped Fashion-MNIST IDX files straight into PyTorch for training.

Expected files in ``dataPath``:
    t10k-images-idx3-ubyte.gz
    t10k-labels-idx1-ubyte.gz
    train-images-idx3-ubyte.gz
    train-labels-idx1-ubyte.gz
"""
import os
import numpy as np
import gzip
import matplotlib.pyplot as plt
import torch
import torch.utils.data as Data
from torchvision import datasets, transforms
from torch.autograd import Variable
import time

# Directory holding the .gz files. It can be overridden with the
# FASHION_MNIST_DIR environment variable instead of editing the source.
dataPath = os.environ.get('FASHION_MNIST_DIR', 'E:/fashion_binary_gz/')
# Number of images previewed in the __main__ plot (the DataLoaders below use 100).
batch_size = 4

# IDX magic numbers: 0x00000801 = unsigned-byte 1-D (label) file,
# 0x00000803 = unsigned-byte 3-D (image) file.
_IDX1_UBYTE_MAGIC = 2049
_IDX3_UBYTE_MAGIC = 2051


def _read_gz(path):
    """Return the fully decompressed bytes of the gzip file at *path*."""
    with gzip.open(path, 'rb') as f:  # 'rb': read raw binary data
        return f.read()


def load_data(data_folder, data_name, label_name):
    """Load one gzipped IDX image file and its matching gzipped IDX label file.

    Args:
        data_folder: directory containing both files.
        data_name: file name of the gzipped IDX3 (image) file.
        label_name: file name of the gzipped IDX1 (label) file.

    Returns:
        ``(x_train, y_train)``. ``x_train`` is a writable ``uint8`` array of
        shape ``(N, rows, cols)``, with rows/cols taken from the file header.
        ``y_train`` is a writable ``uint8`` array of shape ``(N,)``.

    Raises:
        ValueError: if either file has the wrong magic number, the image and
            label counts differ, or a file is shorter than its header claims.
    """
    label_bytes = _read_gz(os.path.join(data_folder, label_name))
    # IDX1 header: big-endian uint32 magic, then uint32 item count (8 bytes).
    magic, n_labels = np.frombuffer(label_bytes, dtype='>u4', count=2).tolist()
    if magic != _IDX1_UBYTE_MAGIC:
        raise ValueError(
            f'{label_name}: not an unsigned-byte IDX1 label file (magic {magic})')
    # .copy(): np.frombuffer returns a read-only view, and torch.from_numpy
    # (used by transforms.ToTensor) warns on non-writable arrays.
    y_train = np.frombuffer(label_bytes, np.uint8, count=n_labels, offset=8).copy()

    image_bytes = _read_gz(os.path.join(data_folder, data_name))
    # IDX3 header: magic, image count, rows, cols as big-endian uint32 (16 bytes).
    magic, n_images, rows, cols = np.frombuffer(
        image_bytes, dtype='>u4', count=4).tolist()
    if magic != _IDX3_UBYTE_MAGIC:
        raise ValueError(
            f'{data_name}: not an unsigned-byte IDX3 image file (magic {magic})')
    if n_images != n_labels:
        raise ValueError(
            f'{data_name} has {n_images} images but {label_name} has {n_labels} labels')
    x_train = (
        np.frombuffer(image_bytes, np.uint8, count=n_images * rows * cols, offset=16)
        .reshape(n_images, rows, cols)
        .copy()
    )
    return (x_train, y_train)


class DealDataset(Data.Dataset):
    """Map-style Dataset over one Fashion-MNIST split read with ``load_data``.

    Attributes:
        train_set: uint8 image array of shape (N, rows, cols).
        train_labels: uint8 label array of shape (N,).
        transform: optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, folder, data_name, label_name, transform=None):
        (train_set, train_labels) = load_data(folder, data_name, label_name)
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)``.

        The image is passed through ``transform`` when one was given. The
        label is returned as a Python int.
        """
        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)


# Instantiating DealDataset yields Dataset objects that can be handed straight to DataLoader.
trainDataset = DealDataset(dataPath, "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                           transform=transforms.ToTensor())
testDataset = DealDataset(dataPath, "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz",
                          transform=transforms.ToTensor())

# Training and test loaders; each batch holds 100 images.
# NOTE(review): shuffle=False on the training loader is kept from the original;
# training usually wants shuffle=True, so confirm this is intended.
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=100,
    shuffle=False,
)
test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=100,
    shuffle=False,
)

if __name__ == '__main__':
    # trainDataset holds train_labels and train_set, both numpy ndarrays.
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')

    # train_loader exposes batch_size (int) and dataset (the DealDataset above),
    # whose train_labels / train_set are numpy ndarrays.
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')

    dataiter = iter(train_loader)
    # DataLoader iterators have no .next() method in current PyTorch; use the builtin.
    images, labels = next(dataiter)
    images = images.numpy()

    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    # Plot the first `batch_size` images of the batch with their labels, on 2 rows.
    fig = plt.figure(figsize=(25, 4))
    n_show = min(batch_size, len(images))
    # add_subplot needs an integer column count; round up so odd counts still fit.
    n_cols = (n_show + 1) // 2
    for idx in range(n_show):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        ax.imshow(np.squeeze(images[idx]), cmap='gray')
        ax.set_title(classes[int(labels[idx])])
    plt.show()
运行结果
显示图像
总结
以上是生活随笔为你收集整理的《Pytorch Fashion_MNIST直接离线加载二进制文件到pytorch》的全部内容，希望文章能够帮你解决所遇到的问题。
- 上一篇: Pytorch MNIST直接离线加载二
- 下一篇: matplotlib markers的类