深度学习修炼(二)——数据集的加载
文章目錄
- 致謝
- 2 數據集的加載
- 2.1 框架數據集的加載
- 2.2 自定義數據集
- 2.3 準備數據以進行數據加載器訓練
致謝
Pytorch自帶數據集介紹_godblesstao的博客-CSDN博客_pytorch自帶數據集
2 數據集的加載
與sklearn中的datasets自帶數據集類似,pytorch框架也為我們提供了數據集以便一系列的模型測試。其數據集作為一個類繼承自父類torch.utils.data.Dataset。
2.1 框架數據集的加載
讓我們看看torch為我們提供了什么數據集。數據集種類如下所示:
-
手寫字符識別:EMNIST、MNIST、QMNIST、USPS、SVHN、KMNIST、Omniglot
-
實物分類:Fashion MNIST、CIFAR、LSUN、SLT-10、ImageNet
-
人臉識別:CelebA
-
場景分類:LSUN、Places365
-
用于object detection:SVHN、VOCDetection、COCODetection
-
用于semantic/instance segmentation:
-
語義分割:Cityscapes、VOCSegmentation
-
語義邊界:SBD
-
用于image captioning:Flickr、COCOCaption
-
用于video classification:HMDB51、Kinetics
-
用于3D reconstruction:PhotoTour
-
用于shadow detectors:SBU
以FashionMNIST數據集為例,我們看一下如何加載數據集。
torch.datasets.FashionMNIST(root = “data”,train = True,download = True,transform = ToTensor())
- root是存儲訓練/測試數據的路徑
- train指定訓練或測試數據集,當布爾值為True則為訓練集,當布爾值為False則為測試集
- download=True從互聯網下載數據(如果無法在本地獲得)
- transform指定特征轉換方式,target_transform指定標簽轉換方式
數據集加載完實際上是以類的形式存在的,其不同于sklearn中返回的Bunch。
如果我們想要看看數據集中有啥要怎么做呢?首先,這個數據集是圖像分類數據集,說明里面含有的都是圖像,為此,我們可以使用subplots存放這些圖片。對于這些數據集,我們可以像列表一樣手動索引。如train_data[index]。
import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as pltdef load_data():"""加載數據集"""# 1 訓練數據集的加載train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 測試數據集的加載test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""數據集可視化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 從訓練集中隨機抽出九張圖(九個樣本)for i in range(1, cols * rows + 1):# 設置索引,索引取值為0到訓練集的長度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出對應樣本的圖片和標簽img, label = train_data[sample_idx]# 依次畫于事先指定的九宮格圖上figure.add_subplot(rows, cols, i)# 設置對應圖片的標題plt.title(label_map[label])# 關掉坐標軸plt.axis("off")# 展示圖片plt.imshow(img.squeeze(), cmap="gray")# 釋放畫布plt.show()train_data, test_data = load_data() show_data(train_data)out:
上面用到了一個API,即torch.randint()
torch.randint(low=0, high, size, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
- 用于取隨機整數,返回值為張量
- low:int類型,表明要從分布中提取的最低整數
- high:int類型,表明要從分布中提取的最高整數1
- size:元組類型,表明輸出張量的形狀
- dtype:返回值張量的數據類型
- device:返回張量所需的設備
- requires_grad:布爾類型,表明是否要對返回的張量自動求導。
如:
torch.randint(3, 5, (3,)) tensor([4, 3, 4])意味生成一個一維的3元素向量,其中向量中的元素取值從3-5取。
2.2 自定義數據集
如果你不想使用框架自帶的數據集,那么你可以自己定義一個數據集類。自定義Dataset類必須實現三個函數:__ init __ 、 __ len __ 、__ getitem __。其中圖像部分存儲于一個文件夾中,標簽單獨存儲在CSV文件中。
在接下來的代碼中,讓我們看看如何創建一個自定義數據集。
import os import pandas as pd from torchvision.io import read_imageclass CustomImageDataset(Dataset):def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_labels)def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label對于__ init __ 函數來說,包含加載圖像、注釋文件和兩個轉換的目錄,在這里我們不做過多講解,后面會詳細介紹。
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):self.img_labels = pd.read_csv(annotations_file)self.img_dir = img_dirself.transform = transformself.target_transform = target_transform對于__ len __ 函數,其功能是返回數據集中的樣本數。
def __len__(self):return len(self.img_labels)對于 __ getitem __,其功能是給定索引便能返回對應樣本。
def __getitem__(self, idx):img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])image = read_image(img_path)label = self.img_labels.iloc[idx, 1]if self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, label在自定義這一部分不用過多的去了解,用著用著就會了,就算不會代碼也是通用,需要用的時候看一下復制一下,別搞得自己這么焦慮。
2.3 準備數據以進行數據加載器訓練
在pytorch中,數據加載的核心實際上是torch.utils.data.DataLoader類,它支持對torch數據集的python可迭代,換而言之,DataLoader相當于你拿一個水盆,而dataset相當于泉水。DataLoader可以對小批量數據集進行處理,處理內容包括:
- 地圖樣式和可迭代樣式的數據集
- 自定義數據集加載順序
- 多進程加載數據
- 自動內存固定
其中地圖樣式數據集是指自定義數據集,而可迭代樣式數據集指的是自帶數據集。其他詳情對于初學者來說很不友好,這里不做過多解釋,你可以理解為這就是個科普知識。
我們來看一下這個API吧。
torch.utils.data.DataLoader(數據集, batch_size=1, shuffle=False)
- 用于加載樣本并且進行批處理
- 數據集:要加載的數據集
- batch_size:整數類型,表明每批要加載的樣本數,默認為1
- shuffle:布爾類型,表明是否要洗牌
我們利用上面的API來加載我們上面的Fashion_MNIST吧。
def load_batch_data():"""數據集批處理加載器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloader既然已經將樣本導入加載器,那么我們如何從加載器中讀取數據呢?我們可以根據需要循環訪問數據集。
import torch from torch.utils.data import Dataset from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt from torch.utils.data import DataLoaderdef load_data():"""加載數據集"""# 1 訓練數據集的加載train_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor())# 2 測試數據集的加載test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor())return train_data, test_datadef show_data(train_data):"""數據集可視化"""label_map = {0: "T_Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",}figure = plt.figure(figsize=(8, 8))cols, rows = 3, 3# 從訓練集中隨機抽出九張圖(九個樣本)for i in range(1, cols * rows + 1):# 設置索引,索引取值為0到訓練集的長度sample_idx = torch.randint(len(train_data), size=(1,)).item()# 取出對應樣本的圖片和標簽img, label = train_data[sample_idx]# 依次畫于事先指定的九宮格圖上figure.add_subplot(rows, cols, i)# 設置對應圖片的標題plt.title(label_map[label])# 關掉坐標軸plt.axis("off")# 展示圖片plt.imshow(img.squeeze(), cmap="gray")# 釋放畫布plt.show()def load_batch_data():"""數據集批處理加載器"""train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)return train_dataloader, test_dataloaderdef show_batch_data():"""循環訪問數據加載器"""train_dataloader, test_dataloader = load_batch_data()train_feature, train_labels = next(iter(train_dataloader))print(f"特征大小:{train_feature.size()}")print(f"標簽大小:{train_labels.size()}")img = train_feature[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"label:{label}")train_data, test_data = load_data() # show_data(train_data) show_batch_data() 創作挑戰賽新人創作獎勵來咯,堅持創作打卡瓜分現金大獎總結
以上是生活随笔為你收集整理的深度学习修炼(二)——数据集的加载的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 编程思想演进
- 下一篇: 如何得到别人的上网帐号和密码