pytorch基础知识整理(二)数据加载
pytorch數(shù)據(jù)加載組件位于torch.utils.data中。
from torch.utils.data import DataLoader, Dataset, Sampler1, torch.utils.data.DataLoader
pytorch提供的數(shù)據(jù)加載器,它返回一個(gè)可迭代對象。不使用這個(gè)DataLoader,直接手動把每batch數(shù)據(jù)導(dǎo)入顯存當(dāng)然也可以,但是DataLoader類可以使跑模型和加載數(shù)據(jù)并行進(jìn)行,效率高且更加靈活,所以通常都應(yīng)該用DataLoader來加載數(shù)據(jù)。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)基本上從字面就能看懂各參數(shù)的含義。其中num_workers是指開多進(jìn)程加載數(shù)據(jù),似乎在windows上不支持大于0的數(shù)字。pin_memory是指是否把數(shù)據(jù)固定到內(nèi)存中。
然后再用分別分裝數(shù)據(jù)集并用DataLoader調(diào)用:
訓(xùn)練或推理時(shí)從DataLoader中取數(shù)據(jù)的方法一般如下:
for i in range(epoch):for batch_idx, (data, target) in enumerate(train_loader):if use_gpu:data, target = data.cuda(), target.cuda()建議在訓(xùn)練模型前,先分別運(yùn)行一次僅加載數(shù)據(jù)不跑模型的過程和僅跑模型不加載數(shù)據(jù)的過程,分別記錄兩個(gè)過程的時(shí)間以評估數(shù)據(jù)加載過程的耗時(shí)在訓(xùn)練過程中的比例,并據(jù)此考慮是否采取更復(fù)雜的措施提高數(shù)據(jù)加載速度。
2 torch.utils.data.Dataset
必須要先把數(shù)據(jù)構(gòu)造成dataset類型才能被DataLoader調(diào)用,支持兩種類型,一種是匹配型Dataset類,也就是其中定義了__getitem__()和__len__()方法,這種比較常用;另一種是迭代型IterabelDataset類,也就是其中定義了__iter__()方法的。
2.1 torch.utils.data.TensorDataset
把tensor直接包裝成dataset,通常數(shù)據(jù)不需要處理可直接用,且數(shù)據(jù)量不是太大的情況下使用。
注意:新版本pytorch中data_tensor, target_tensor兩個(gè)參數(shù)名已取消,直接放數(shù)據(jù)就可以,再指名參數(shù)名會報(bào)錯(cuò)。
2.2 torch.utils.data.Dataset
封裝dataset的基本類,可實(shí)現(xiàn)各種情況下非常靈活的數(shù)據(jù)集加載,使用時(shí)需要重寫它的__getitem__和__len__方法。
class DealDataset(Dataset):def __init__(self,mode='train'):X, y, Xt, yt = get_data()if mode=='train':self.x_data = Xself.y_data = yelif mode=='test':self.x_data = Xtself.y_data = ytself.len = self.x_data.shape[0]def __getitem__(self, index):data = self.x_data[index]target = self.y_data[index]return data, targetdef __len__(self):return self.len在完成一項(xiàng)較大的建模工程時(shí),通常需要試驗(yàn)各種各樣的數(shù)據(jù)處理方案,因此數(shù)據(jù)加載方案要被大量修改,為了便于修改的靈活性、代碼的整潔性和避免修改的版本混亂,可以先使用一個(gè)BaseDataset基類確定肯定不會變的文件路徑等內(nèi)容,再使用子類繼承來獲得各種版本的數(shù)據(jù)處理方案。
3 torch.utils.data.Sampler
通常sampler不是必須的,但使用sampler可以更靈活的定義采樣次序,可以使用SequentialSampler順序采樣;RandomSampler隨機(jī)采樣(有放回或無放回);WeightedRandomSampler按權(quán)重隨機(jī)采樣;BatchSampler在一個(gè)batch中封裝一個(gè)其他的采樣器,返回一個(gè)batch大小的index索引。
也可以通過重寫Sampler類或其他子類中的__iter__()方法實(shí)現(xiàn)更靈活的自定義采樣器。
4, torchvision.transforms
對圖像類數(shù)據(jù)進(jìn)行處理時(shí)經(jīng)常用到trochvision.transforms
.Compose(transforms)用來把多種變換組合起來
各種變換有:
.CenterCrop(size) 中心切割
.Resize((224,224)) 尺寸變換
.RandomCrop(size, padding=0) 隨機(jī)中心點(diǎn)切割
.RandomHorizontalFlip() 隨機(jī)水平翻轉(zhuǎn)
.RandomSizedCrop(size, interpolation=2) 隨機(jī)大小切割,然后再resize到size大小
.Pad(padding, fill=0) 四周pad
.Normalize(mean, std) 標(biāo)準(zhǔn)化
.ToTensor() 將PIL.Image或np.ndarray轉(zhuǎn)換為tensor
.Lambda(lambd) 函數(shù)式自定義變換
總結(jié)
以上是生活随笔為你收集整理的pytorch基础知识整理(二)数据加载的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 用户偏好类结构化数据分析题参赛总结
- 下一篇: pytorch基础知识整理(一)自动求导