PyTorch数据Pipeline标准化代码模板
前言
PyTorch作為一款流行深度學習框架其熱度大有超越TensorFlow的感覺。根據此前的統計,目前TensorFlow雖然仍然占據著工業界,但PyTorch在視覺和NLP領域的頂級會議上已呈一統之勢。
這篇文章筆者將和大家聚焦于PyTorch的自定義數據讀取pipeline模板和相關trciks以及如何優化數據讀取的pipeline等。我們從PyTorch的數據對象類Dataset開始。Dataset在PyTorch中的模塊位于utils.data下。
from?torch.utils.data?import?Dataset本文將圍繞Dataset對象分別從原始模板、torchvision的transforms模塊、使用pandas來輔助讀取、torch內置數據劃分功能和DataLoader來展開闡述。
Dataset原始模板
PyTorch官方為我們提供了自定義數據讀取的標準化代碼代碼模塊,作為一個讀取框架,我們這里稱之為原始模板。其代碼結構如下:
from?torch.utils.data?import?Dataset class CustomDataset(Dataset):def __init__(self, ...):# stuffdef __getitem__(self, index):# stuffreturn (img, label)def __len__(self):#?return?examples?sizereturn count根據這個標準化的代碼模板,我們只需要根據自己的數據讀取任務,分別往__init__()、__getitem__()和__len__()三個方法里添加讀取邏輯即可。作為PyTorch范式下的數據讀取以及為了后續的data loader,三個方法缺一不可。其中:
__init__()函數用于初始化數據讀取邏輯,比如讀取包含標簽和圖片地址的csv文件、定義transform組合等。
__getitem__()函數用來返回數據和標簽。目的上是為了能夠被后續的dataloader所調用。
__len__()函數則用于返回樣本數量。
現在我們往這個框架里填幾行代碼來形成一個簡單的數字案例。創建一個從1到100的數字例子:
from?torch.utils.data?import?Dataset class CustomDataset(Dataset):def __init__(self):self.samples = list(range(1, 101))def __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]if __name__ == '__main__':dataset = CustomDataset()print(len(dataset))print(dataset[50])print(dataset[1:100])添加torchvision.transforms
然后我們來看如何從內存中讀取數據以及如何在讀取過程中嵌入torchvision中的transforms功能。torchvision是一個獨立于torch的關于數據、模型和一些圖像增強操作的輔助庫。主要包括datasets默認數據集模塊、models經典模型模塊、transforms圖像增強模塊以及utils模塊等。在使用torch讀取數據的時候,一般會搭配上transforms模塊對數據進行一些處理和增強工作。
添加了tranforms之后的讀取模塊可以改寫為:
from torch.utils.data import Dataset from torchvision import transforms as Tclass CustomDataset(Dataset):def __init__(self, ...):# stuff...# compose the transforms methodsself.transform?=?T.Compose([T.CenterCrop(100),T.ToTensor()])def __getitem__(self, index):# stuff...data?=?#?Some?data?read?from?a?file?or?image# execute the transformdata = self.transform(data)return (img, label)def __len__(self):# return examples sizereturn countif __name__ == '__main__':# Call the datasetcustom_dataset = CustomDataset(...)可以看到,我們使用了Compose方法來把各種數據處理方法聚合到一起進行定義數據轉換方法。通常作為初始化方法放在__init__()函數下。我們以貓狗圖像數據為例進行說明。
定義數據讀取方法如下:
class?DogCat(Dataset):????def?__init__(self,?root,?transforms=None,?train=True,?val=False):"""get?images?and?execute?transforms."""self.val?=?valimgs?=?[os.path.join(root,?img)?for?img?in?os.listdir(root)]#?train:?Cats_Dogs/trainset/cat.1.jpg#?val:?Cats_Dogs/valset/cat.10004.jpgimgs?=?sorted(imgs,?key=lambda?x:?x.split('.')[-2])self.imgs?=?imgs?????????if?transforms?is?None:#?normalize??????normalize?=?T.Normalize(mean?=?[0.485,?0.456,?0.406],std?=?[0.229,?0.224,?0.225])#?trainset?and?valset?have?different?data?transform?#?trainset?need?data?augmentation?but?valset?don't.# valsetif?self.val:self.transforms?=?T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])#?trainsetelse:self.transforms?=?T.Compose([T.Resize(256),T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),normalize])def?__getitem__(self,?index):"""return?data?and?label"""img_path?=?self.imgs[index]label?=?1?if?'dog'?in?img_path.split('/')[-1]?else?0data?=?Image.open(img_path)data?=?self.transforms(data)return data, labeldef?__len__(self):"""return?images?size."""return len(self.imgs)if?__name__?==?"__main__":train_dataset?=?DogCat('./Cats_Dogs/trainset/',?train=True)print(len(train_dataset))print(train_dataset[0])因為這個數據集已經分好了訓練集和驗證集,所以在讀取和transforms的時候需要進行區分。運行示例如下:
與pandas一起使用
很多時候數據的目錄地址和標簽都是通過csv文件給出的。如下所示:
此時在數據讀取的pipeline中我們需要在__init__()方法中利用pandas把csv文件中包含的圖片地址和標簽融合進去。相應的數據讀取pipeline模板可以改寫為:
class?CustomDatasetFromCSV(Dataset):def?__init__(self,?csv_path):"""Args:csv_path?(string):?path?to?csv?filetransform:?pytorch?transforms?for?transforms?and?tensor?conversion"""#?Transformsself.to_tensor?=?transforms.ToTensor()#?Read?the?csv?fileself.data_info?=?pd.read_csv(csv_path,?header=None)#?First?column?contains?the?image?pathsself.image_arr?=?np.asarray(self.data_info.iloc[:,?0])#?Second?column?is?the?labelsself.label_arr?=?np.asarray(self.data_info.iloc[:,?1])#?Calculate?lenself.data_len = len(self.data_info.index)def?__getitem__(self,?index):#?Get?image?name?from?the?pandas?dfsingle_image_name?=?self.image_arr[index]#?Open?imageimg_as_img?=?Image.open(single_image_name)#?Transform?image?to?tensorimg_as_tensor?=?self.to_tensor(img_as_img)#?Get?label?of?the?image?based?on?the?cropped?pandas?columnsingle_image_label?=?self.label_arr[index]return (img_as_tensor, single_image_label)def?__len__(self):return self.data_lenif?__name__?==?"__main__":#?Call?datasetdataset = CustomDatasetFromCSV('./labels.csv')以mnist_label.csv文件為示例:
from?torch.utils.data?import?Dataset from?torch.utils.data?import?DataLoader from?torchvision?import?transforms?as?T from?PIL?import?Image import?os import?numpy?as?np import pandas as pdclass?CustomDatasetFromCSV(Dataset):def?__init__(self,?csv_path):"""Args:csv_path?(string):?path?to?csv?file????????????transform:?pytorch?transforms?for?transforms?and?tensor?conversion"""#?Transformsself.to_tensor?=?T.ToTensor()#?Read?the?csv?fileself.data_info?=?pd.read_csv(csv_path,?header=None)#?First?column?contains?the?image?pathsself.image_arr?=?np.asarray(self.data_info.iloc[:,?0])#?Second?column?is?the?labelsself.label_arr?=?np.asarray(self.data_info.iloc[:,?1])#?Third?column?is?for?an?operation?indicatorself.operation_arr?=?np.asarray(self.data_info.iloc[:,?2])#?Calculate?lenself.data_len = len(self.data_info.index)def?__getitem__(self,?index):#?Get?image?name?from?the?pandas?dfsingle_image_name?=?self.image_arr[index]#?Open?imageimg_as_img?=?Image.open(single_image_name)#?Check?if?there?is?an?operationsome_operation?=?self.operation_arr[index]#?If?there?is?an?operationif?some_operation:#?Do?some?operation?on?image#?...#?...pass#?Transform?image?to?tensorimg_as_tensor?=?self.to_tensor(img_as_img)#?Get?label?of?the?image?based?on?the?cropped?pandas?columnsingle_image_label?=?self.label_arr[index]return (img_as_tensor, single_image_label)def?__len__(self):return self.data_lenif?__name__?==?"__main__":transform?=?T.Compose([T.ToTensor()])dataset?=?CustomDatasetFromCSV('./mnist_labels.csv')print(len(dataset))print(dataset[5])運行示例如下:
訓練集驗證集劃分
一般來說,為了模型訓練的穩定,我們需要對數據劃分訓練集和驗證集。torch的Dataset對象也提供了random_split函數作為數據劃分工具,且劃分結果可直接供后續的DataLoader使用。
以kaggle的花朵數據為例:
from?torch.utils.data?import?DataLoader from?torchvision.datasets?import?ImageFolder from?torchvision?import?transforms?as?T from torch.utils.data import random_splittransform?=?T.Compose([T.Resize((224,?224)),T.RandomHorizontalFlip(),T.ToTensor()])dataset?=?ImageFolder('./flowers_photos',?transform=transform) print(dataset.class_to_idx)trainset,?valset?=?random_split(dataset,?[int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])trainloader?=?DataLoader(dataset=trainset,?batch_size=32,?shuffle=True,?num_workers=1) for?i,?(img,?label)?in?enumerate(trainloader):img,?label?=?img.numpy(),?label.numpy()print(img, label)valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1) for?i,?(img,?label)?in?enumerate(trainloader):img,?label?=?img.numpy(),?label.numpy()print(img.shape, label)這里使用了ImageFolder模塊,可以直接讀取各標簽對應的文件夾,部分運行示例如下:
使用DataLoader
dataset方法寫好之后,我們還需要使用DataLoader將其逐個喂給模型。上一節的數據劃分我們已經用到了DataLoader函數。從本質上來講,DataLoader只是調用了__getitem__()方法并按批次返回數據和標簽。使用方法如下:
from?torch.utils.data?import?DataLoader from torchvision import transforms as Tif?__name__?==?"__main__":#?Define?transformstransformations?=?T.Compose([T.ToTensor()])#?Define?custom?datasetdataset?=?CustomDatasetFromCSV('./labels.csv')#?Define?data?loaderdata_loader?=?DataLoader(dataset=dataset,?batch_size=10,?shuffle=True)for?images,?labels?in?data_loader:# Feed the data to the model以上就是PyTorch讀取數據的Pipeline主要方法和流程。基于Dataset對象的基本框架不變,具體細節可自定義化調整。
本文原創首發于公眾號【機器學習實驗室】,開創了【深度學習60講】、【機器學習算法手推30講】和【深度學習100問】三大系列文章。
一個算法工程師的成長之路
長按二維碼.關注機器學習實驗室
機器學習實驗室的近期文章:
機器學習公式推導和算法手寫之XGBoost
機器學習公式推導和算法手寫之馬爾科夫鏈蒙特卡洛
如何部署一個輕量級深度學習項目?
基于C++的PyTorch模型部署
PyTorch數據Pipeline標準化代碼模板
算法工程師的一天
參考文獻
【1】https://pytorch.org/docs/stable/data.html
【2】https://towardsdatascience.com/building-efficient-custom-datasets-in-pytorch-2563b946fd9f
【3】https://github.com/utkuozbulak/pytorch-custom-dataset-examples
夕小瑤的賣萌屋
_
關注&星標小夕,帶你解鎖AI秘籍
訂閱號主頁下方「撩一下」有驚喜哦
總結
以上是生活随笔為你收集整理的PyTorch数据Pipeline标准化代码模板的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 剑桥大学终身教授T.S.:7大机器学习算
- 下一篇: Sigmoid函数与Softmax函数的