Dataset和DataLoader构建数据通道
重點在第二部分的構建數據通道和第三部分的加載數據集
Pytorch通常使用Dataset和DataLoader這兩個工具類來構建數據管道。
Dataset定義了數據集的內容,它相當于一個類似列表的數據結構,具有確定的長度,能夠用索引獲取數據集中的元素。
而DataLoader定義了按batch加載數據集的方法,它是一個實現了__iter__方法的可迭代對象,每次迭代輸出一個batch的數據。
DataLoader能夠控制batch的大小,batch中元素的采樣方法,以及將batch結果整理成模型所需輸入形式的方法,并且能夠使用多進程讀取數據。
在絕大部分情況下,用戶只需實現Dataset的__len__方法和__getitem__方法,就可以輕松構建自己的數據集,并用默認數據管道進行加載。
一,Dataset和DataLoader概述
1,獲取一個batch數據的步驟
讓我們考慮一下從一個數據集中獲取一個batch的數據需要哪些步驟。
(假定數據集的特征和標簽分別表示為張量X和Y,數據集可以表示為(X,Y), 假定batch大小為m)
1,首先我們要確定數據集的長度n。
結果類似:n = 1000。
2,然后我們從0到n-1的范圍中抽樣出m個數(batch大小)。
假定m=4, 拿到的結果是一個列表,類似:indices = [1,4,8,9]
3,接著我們從數據集中去取這m個數對應下標的元素。
拿到的結果是一個元組列表,類似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4,最后我們將結果整理成兩個張量作為輸出。
拿到的結果是兩個張量,類似batch = (features,labels),
其中 features = torch.stack([X[1],X[4],X[8],X[9]])
labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])
2,Dataset和DataLoader的功能分工
上述第1個步驟確定數據集的長度是由 Dataset的__len__ 方法實現的。
第2個步驟從0到n-1的范圍中抽樣出m個數的方法是由 DataLoader的 sampler和 batch_sampler參數指定的。
sampler參數指定單個元素抽樣方法,一般無需用戶設置,程序默認在DataLoader的參數shuffle=True時采用隨機抽樣,shuffle=False時采用順序抽樣。
batch_sampler參數將多個抽樣的元素整理成一個列表,一般無需用戶設置,默認方法在DataLoader的參數drop_last=True時會丟棄數據集最后一個長度不能被batch大小整除的批次,在drop_last=False時保留最后一個批次。
第3個步驟的核心邏輯根據下標取數據集中的元素 是由 Dataset的 __getitem__方法實現的。
第4個步驟的邏輯由DataLoader的參數collate_fn指定。一般情況下也無需用戶設置。
3,Dataset和DataLoader的主要接口
偽代碼,實際應用意義不大
import torch class Dataset(object):def __init__(self):passdef __len__(self):raise NotImplementedErrordef __getitem__(self,index):raise NotImplementedErrorclass DataLoader(object):def __init__(self,dataset,batch_size,collate_fn,shuffle = True,drop_last = False):self.dataset = datasetself.sampler =torch.utils.data.RandomSampler if shuffle else \torch.utils.data.SequentialSamplerself.batch_sampler = torch.utils.data.BatchSamplerself.sample_iter = self.batch_sampler(self.sampler(range(len(dataset))),batch_size = batch_size,drop_last = drop_last)def __next__(self):indices = next(self.sample_iter)batch = self.collate_fn([self.dataset[i] for i in indices])return batch二,使用Dataset創建數據集
Dataset創建數據集常用的方法有:
使用 torch.utils.data.TensorDataset 根據Tensor創建數據集(numpy的array,Pandas的DataFrame需要先轉換成Tensor)。
使用 torchvision.datasets.ImageFolder 根據圖片目錄創建圖片數據集。
繼承 torch.utils.data.Dataset 創建自定義數據集。
此外,還可以通過
torch.utils.data.random_split 將一個數據集分割成多份,常用于分割訓練集,驗證集和測試集。
調用Dataset的加法運算符(+)將多個數據集合并成一個數據集。
1,根據Tensor創建數據集
2,根據圖片目錄創建圖片數據集
三,使用DataLoader加載數據集
DataLoader能夠控制batch的大小,batch中元素的采樣方法,以及將batch結果整理成模型所需輸入形式的方法,并且能夠使用多進程讀取數據。
DataLoader的函數簽名
DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,multiprocessing_context=None, )一般情況下,我們僅僅會配置 dataset, batch_size, shuffle, num_workers, drop_last這五個參數,其他參數使用默認值即可。
dataset : 數據集
batch_size: 批次大小
shuffle: 是否亂序
sampler: 樣本采樣函數,一般無需設置。
batch_sampler: 批次采樣函數,一般無需設置。
num_workers: 使用多進程讀取數據,設置的進程數。
collate_fn: 整理一個批次數據的函數。
pin_memory: 是否設置為鎖業內存。默認為False,鎖業內存不會使用虛擬內存(硬盤),從鎖業內存拷貝到GPU上速度會更快。
drop_last: 是否丟棄最后一個樣本數量不足batch_size批次數據。
timeout: 加載一個數據批次的最長等待時間,一般無需設置。
worker_init_fn: 每個worker中dataset的初始化函數,常用于 IterableDataset。一般不使用。
總結
以上是生活随笔為你收集整理的Dataset和DataLoader构建数据通道的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: nn.functional 和 nn.M
- 下一篇: 女人梦到自己游泳是什么预兆