PyTorch MNIST: loading the binary files directly into PyTorch offline
Collected and compiled by 生活随笔. This article shows how to load the MNIST binary files directly into PyTorch offline.
Description: read the four MNIST binary files from a local folder straight into PyTorch, with no download step.
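The `load_data` function in the code skips a fixed number of bytes before reading the pixels and labels. Those bytes are the IDX header: a 4-byte magic number (two zero bytes, a type code of 0x08 meaning unsigned byte, and the number of dimensions), followed by one big-endian 32-bit size per dimension. Label files are 1-D, so the header is 4 + 4 = 8 bytes (magic number 2049 = 0x801); image files are 3-D (count, rows, columns), so it is 4 + 3 × 4 = 16 bytes (magic number 2051 = 0x803). That is where `offset=8` and `offset=16` come from. A minimal sketch that reads the header instead of hard-coding it might look like this, assuming uint8 IDX files (which all four MNIST files are); `read_idx` is a hypothetical helper, not part of NumPy, PyTorch or torchvision.

```python
import gzip
import struct

import numpy as np


def read_idx(path):
    """Hypothetical helper: parse a gzipped IDX file whose elements are uint8.

    Header layout (big-endian): two zero bytes, one type byte
    (0x08 = unsigned byte), one byte with the number of dimensions,
    then one 32-bit unsigned size per dimension.
    """
    with gzip.open(path, "rb") as f:
        data = f.read()
    zero, dtype_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or dtype_code != 0x08:
        raise ValueError(f"not a uint8 IDX file: {path}")
    dims = struct.unpack(f">{ndim}I", data[4:4 + 4 * ndim])
    # header size: 8 bytes for label files (ndim=1), 16 for image files (ndim=3)
    offset = 4 + 4 * ndim
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(dims)
```

For the standard training images, `read_idx(os.path.join(dataPath, "train-images-idx3-ubyte.gz"))` should give the same (60000, 28, 28) array as `load_data`, but the image size and sample count are taken from the file rather than assumed.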
'''
Read the data directly from the following 4 files into PyTorch:
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
'''
import os
import gzip

import numpy as np
import torch.utils.data as Data
from torchvision import transforms

dataPath = 'E:/MNIST/binary_file'


def load_data(data_folder, data_name, label_name):
    """
    data_folder: directory that holds the files
    data_name:   file name of the image data
    label_name:  file name of the label data
    """
    # 'rb' opens the gzip file for reading binary data
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:
        # skip the 8-byte header (magic number, number of labels)
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        # skip the 16-byte header (magic number, number of images, rows, columns);
        # .copy() makes the array writable, otherwise ToTensor warns about a non-writable NumPy array
        x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28).copy()
    return (x_train, y_train)


class DealDataset(Data.Dataset):
    """
    Read and initialize the data
    """
    def __init__(self, folder, data_name, label_name, transform=None):
        # torch.load() is only an alternative if the data was saved as a .pt file
        # (e.g. with torch.save), in which case it returns torch.Tensor objects;
        # it cannot parse these raw IDX .gz files
        (train_set, train_labels) = load_data(folder, data_name, label_name)
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):
        img, target = self.train_set[index], int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)


# Instantiating this class gives us a Dataset; next, we just pass it to a DataLoader.
trainDataset = DealDataset(dataPath, "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", transform=transforms.ToTensor())
testDataset = DealDataset(dataPath, "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

# Load the training and test data
train_loader = Data.DataLoader(
    dataset=trainDataset,
    batch_size=100,  # a batch can be thought of as a package; each package holds 100 images
    shuffle=False,
)

test_loader = Data.DataLoader(
    dataset=testDataset,
    batch_size=100,
    shuffle=False,
)

if __name__ == '__main__':
    # trainDataset has attributes such as train_labels and train_set; both are ndarrays
    print(f'trainDataset.train_labels.shape:{trainDataset.train_labels.shape}\n')
    print(f'trainDataset.train_set.shape:{trainDataset.train_set.shape}\n')

    # train_loader has attributes such as batch_size and dataset, of types int and DealDataset;
    # its dataset in turn has train_labels and train_set, both ndarrays
    print(f'train_loader.batch_size: {train_loader.batch_size}\n')
    print(f'train_loader.dataset.train_labels.shape: {train_loader.dataset.train_labels.shape}\n')
    print(f'train_loader.dataset.train_set.shape: {train_loader.dataset.train_set.shape}\n')
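What the `DataLoader` actually hands to a training loop depends on `transforms.ToTensor()`: for a 2-D `uint8` array like the 28×28 images returned by `DealDataset.__getitem__`, it adds a leading channel dimension and divides by 255, giving a float tensor of shape (1, 28, 28) with values in [0, 1]. With `batch_size=100` and `shuffle=False`, consecutive samples are then stacked, so each image batch has shape (100, 1, 28, 28), and PyTorch's default collation turns the integer targets into an int64 tensor of shape (100,). The NumPy sketch below imitates both steps for illustration only; it is not torchvision's or PyTorch's code, and `to_tensor_like` and `batches` are hypothetical names.

```python
import numpy as np


def to_tensor_like(img):
    """Mimic transforms.ToTensor() on a 2-D uint8 image (NumPy only, illustrative)."""
    chw = img[np.newaxis, :, :]           # (H, W) -> (1, H, W): add a channel axis
    return chw.astype(np.float32) / 255   # uint8 [0, 255] -> float32 [0, 1]


def batches(images, labels, batch_size=100):
    """Yield consecutive (image_batch, label_batch) pairs, as with shuffle=False.

    The last batch is smaller when len(images) is not a multiple of
    batch_size (DataLoader's default drop_last=False behaves the same way).
    """
    for start in range(0, len(images), batch_size):
        chunk = images[start:start + batch_size]
        imgs = np.stack([to_tensor_like(im) for im in chunk])  # (B, 1, H, W)
        yield imgs, labels[start:start + batch_size].astype(np.int64)
```

With the 60,000 training images and `batch_size=100`, this yields 600 full batches, the same count as iterating over `train_loader`.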