【深度学习】在PyTorch中构建高效的自定义数据集
文章來源于磐創(chuàng)AI,作者磐創(chuàng)AI
學(xué)習(xí)Dataset類的來龍去脈,使用干凈的代碼結(jié)構(gòu),同時(shí)最大限度地減少在訓(xùn)練期間管理大量數(shù)據(jù)的麻煩。
神經(jīng)網(wǎng)絡(luò)訓(xùn)練在數(shù)據(jù)管理上可能很難做到“大規(guī)模”。
PyTorch 最近已經(jīng)出現(xiàn)在我的圈子里,盡管對Keras和TensorFlow感到滿意,但我還是不得不嘗試一下。令人驚訝的是,我發(fā)現(xiàn)它非常令人耳目一新,非常討人喜歡,尤其是PyTorch 提供了一個(gè)Pythonic API、一個(gè)更為固執(zhí)己見的編程模式和一組很好的內(nèi)置實(shí)用程序函數(shù)。我特別喜歡的一項(xiàng)功能是能夠輕松地創(chuàng)建一個(gè)自定義的Dataset對象,然后可以與內(nèi)置的DataLoader一起在訓(xùn)練模型時(shí)提供數(shù)據(jù)。
在本文中,我將從頭開始研究PyTorchDataset對象,其目的是創(chuàng)建一個(gè)用于處理文本文件的數(shù)據(jù)集,以及探索如何為特定任務(wù)優(yōu)化管道。我們首先通過一個(gè)簡單示例來了解Dataset實(shí)用程序的基礎(chǔ)知識,然后逐步完成實(shí)際任務(wù)。具體地說,我們想創(chuàng)建一個(gè)管道,從The Elder Scrolls(TES)系列中獲取名稱,這些名稱的種族和性別屬性作為一個(gè)one-hot張量。你可以在我的網(wǎng)站(http://syaffers.xyz/#datasets)上找到這個(gè)數(shù)據(jù)集。
Dataset類的基礎(chǔ)知識
Pythorch允許您自由地對“Dataset”類執(zhí)行任何操作,只要您重寫兩個(gè)子類函數(shù):
-返回?cái)?shù)據(jù)集大小的函數(shù),以及
-函數(shù)的函數(shù)從給定索引的數(shù)據(jù)集中返回一個(gè)樣本。
數(shù)據(jù)集的大小有時(shí)可能是灰色區(qū)域,但它等于整個(gè)數(shù)據(jù)集中的樣本數(shù)。因此,如果數(shù)據(jù)集中有10000個(gè)單詞(或數(shù)據(jù)點(diǎn)、圖像、句子等),則函數(shù)“uuLen_uUu”應(yīng)該返回10000個(gè)。
PyTorch使您可以自由地對Dataset類執(zhí)行任何操作,只要您重寫改類中的兩個(gè)函數(shù)即可:
__len__ 函數(shù):返回?cái)?shù)據(jù)集大小
__getitem__ 函數(shù):返回對應(yīng)索引的數(shù)據(jù)集中的樣本
數(shù)據(jù)集的大小有時(shí)難以確定,但它等于整個(gè)數(shù)據(jù)集中的樣本數(shù)量。因此,如果您的數(shù)據(jù)集中有10,000個(gè)樣本(數(shù)據(jù)點(diǎn),圖像,句子等),則__len__函數(shù)應(yīng)返回10,000。
一個(gè)小示例
首先,創(chuàng)建一個(gè)從1到1000所有數(shù)字的Dataset來模擬一個(gè)簡單的數(shù)據(jù)集。我們將其適當(dāng)?shù)孛麨镹umbersDataset。
from?torch.utils.data?import?Datasetclass?NumbersDataset(Dataset):def?__init__(self):self.samples?=?list(range(1,?1001))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):return?self.samples[idx]if?__name__?==?'__main__':dataset?=?NumbersDataset()print(len(dataset))print(dataset[100])print(dataset[122:361])
很簡單,對吧?首先,當(dāng)我們初始化NumbersDataset時(shí),我們立即創(chuàng)建一個(gè)名為samples的列表,該列表將存儲(chǔ)1到1000之間的所有數(shù)字。列表的名稱是任意的,因此請隨意使用您喜歡的名稱。需要重寫的函數(shù)是不用我說明的(我希望!),并且對在構(gòu)造函數(shù)中創(chuàng)建的列表進(jìn)行操作。如果運(yùn)行該python文件,將看到1000、101和122到361之間的值,它們分別指的是數(shù)據(jù)集的長度,數(shù)據(jù)集中索引為100的數(shù)據(jù)以及索引為121到361之間的數(shù)據(jù)集切片。
擴(kuò)展數(shù)據(jù)集
讓我們擴(kuò)展此數(shù)據(jù)集,以便它可以存儲(chǔ)low和high之間的所有整數(shù)。
from?torch.utils.data?import?Datasetclass?NumbersDataset(Dataset):def?__init__(self,?low,?high):self.samples?=?list(range(low,?high))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):return?self.samples[idx]if?__name__?==?'__main__':dataset?=?NumbersDataset(2821,?8295)print(len(dataset))print(dataset[100])print(dataset[122:361])運(yùn)行上面代碼應(yīng)在控制臺(tái)打印5474、2921和2943到3181之間的數(shù)字。通過編寫構(gòu)造函數(shù),我們現(xiàn)在可以將數(shù)據(jù)集的low和high設(shè)置為我們的想要的內(nèi)容。這個(gè)簡單的更改顯示了我們可以從PyTorch的Dataset類獲得的各種好處。例如,我們可以生成多個(gè)不同的數(shù)據(jù)集并使用這些值,而不必像在NumPy中那樣,考慮編寫新的類或創(chuàng)建許多難以理解的矩陣。
從文件讀取數(shù)據(jù)
讓我們來進(jìn)一步擴(kuò)展Dataset類的功能。PyTorch與Python標(biāo)準(zhǔn)庫的接口設(shè)計(jì)得非常優(yōu)美,這意味著您不必?fù)?dān)心集成功能。在這里,我們將
創(chuàng)建一個(gè)全新的使用Python I/O和一些靜態(tài)文件的Dataset類
收集TES角色名稱(我的網(wǎng)站上(http://syaffers.xyz/#datasets)有可用的數(shù)據(jù)集),這些角色名稱分為種族文件夾和性別文件,以填充samples列表
通過在samples列表中存儲(chǔ)一個(gè)元組而不只是名稱本身來跟蹤每個(gè)名稱的種族和性別。
TES名稱數(shù)據(jù)集具有以下目錄結(jié)構(gòu):
. |--?Altmer/ |???|--?Female |???`--?Male |--?Argonian/ |???|--?Female |???`--?Male ...?(truncated?for?brevity)(為了簡潔,這里進(jìn)行省略) `--?Redguard/|--?Female`--?Male每個(gè)文件都包含用換行符分隔的TES名稱,因此我們必須逐行讀取每個(gè)文件,以捕獲每個(gè)種族和性別的所有字符名稱。
import?os from?torch.utils.data?import?Datasetclass?TESNamesDataset(Dataset):def?__init__(self,?data_root):self.samples?=?[]for?race?in?os.listdir(data_root):race_folder?=?os.path.join(data_root,?race)for?gender?in?os.listdir(race_folder):gender_filepath?=?os.path.join(race_folder,?gender)with?open(gender_filepath,?'r')?as?gender_file:for?name?in?gender_file.read().splitlines():self.samples.append((race,?gender,?name))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):return?self.samples[idx]if?__name__?==?'__main__':dataset?=?TESNamesDataset('/home/syafiq/Data/tes-names/')print(len(dataset))print(dataset[420])我們來看一下代碼:首先創(chuàng)建一個(gè)空的samples列表,然后遍歷每個(gè)種族(race)文件夾和性別文件并讀取每個(gè)文件中的名稱來填充該列表。然后將種族,性別和名稱存儲(chǔ)在元組中,并將其添加到samples列表中。運(yùn)行該文件應(yīng)打印19491和('Bosmer', 'Female', 'Gluineth')(每臺(tái)計(jì)算機(jī)的輸出可能不太一樣)。讓我們看一下將數(shù)據(jù)集的一個(gè)batch的樣子:
#?將main函數(shù)改成下面這樣: if?__name__?==?'__main__':dataset?=?TESNamesDataset('/home/syafiq/Data/tes-names/')print(dataset[10:60])
正如您所想的,它的工作原理與列表完全相同。對本節(jié)內(nèi)容進(jìn)行總結(jié),我們剛剛將標(biāo)準(zhǔn)的Python I/O 引入了PyTorch數(shù)據(jù)集中,并且我們不需要任何其他特殊的包裝器或幫助器,只需要單純的Python代碼。實(shí)際上,我們還可以包括NumPy或Pandas之類的其他庫,并且通過一些巧妙的操作,使它們在PyTorch中發(fā)揮良好的作用。讓我們現(xiàn)在來看看在訓(xùn)練時(shí)如何有效地遍歷數(shù)據(jù)集。
用DataLoader加載數(shù)據(jù)
盡管Dataset類是創(chuàng)建數(shù)據(jù)集的一種不錯(cuò)的方法,但似乎在訓(xùn)練時(shí),我們將需要對數(shù)據(jù)集的samples列表進(jìn)行索引或切片。這并不比我們對列表或NumPy矩陣進(jìn)行操作更簡單。PyTorch并沒有沿這條路走,而是提供了另一個(gè)實(shí)用工具類DataLoader。DataLoader充當(dāng)Dataset對象的數(shù)據(jù)饋送器(feeder)。如果您熟悉的話,這個(gè)對象跟Keras中的flow數(shù)據(jù)生成器函數(shù)很類似。DataLoader需要一個(gè)Dataset對象(它延伸任何子類)和其他一些可選參數(shù)(參數(shù)都列在PyTorch的DataLoader文檔(https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)中)。在這些參數(shù)中,我們可以選擇對數(shù)據(jù)進(jìn)行打亂,確定batch的大小和并行加載數(shù)據(jù)的線程(job)數(shù)量。這是TESNamesDataset在循環(huán)中進(jìn)行調(diào)用的一個(gè)簡單示例。
#?將main函數(shù)改成下面這樣: if?__name__?==?'__main__':from?torch.utils.data?import?DataLoaderdataset?=?TESNamesDataset('/home/syafiq/Data/tes-names/')dataloader?=?DataLoader(dataset,?batch_size=50,?shuffle=True,?num_workers=2)for?i,?batch?in?enumerate(dataloader):print(i,?batch)當(dāng)您看到大量的batch被打印出來時(shí),您可能會(huì)注意到每個(gè)batch都是三元組的列表:第一個(gè)元組包含種族,下一個(gè)元組包含性別,最后一個(gè)元祖包含名稱。
等等,那不是我們之前對數(shù)據(jù)集進(jìn)行切片時(shí)的樣子!這里到底發(fā)生了什么?好吧,事實(shí)證明,DataLoader以系統(tǒng)的方式加載數(shù)據(jù),以便我們垂直而非水平來堆疊數(shù)據(jù)。這對于一個(gè)batch的張量(tensor)流動(dòng)特別有用,因?yàn)閺埩看怪倍询B(即在第一維上)構(gòu)成batch。此外,DataLoader還會(huì)為對數(shù)據(jù)進(jìn)行重新排列,因此在發(fā)送(feed)數(shù)據(jù)時(shí)無需重新排列矩陣或跟蹤索引。
張量(tensor)和其他類型
為了進(jìn)一步探索不同類型的數(shù)據(jù)在DataLoader中是如何加載的,我們將更新我們先前模擬的數(shù)字?jǐn)?shù)據(jù)集,以產(chǎn)生兩對張量數(shù)據(jù):數(shù)據(jù)集中每個(gè)數(shù)字的后4個(gè)數(shù)字的張量,以及加入一些隨機(jī)噪音的張量。為了拋出DataLoader的曲線球,我們還希望返回?cái)?shù)字本身,而不是張量類型,是作為Python字符串返回。__getitem__函數(shù)將在一個(gè)元組中返回三個(gè)異構(gòu)數(shù)據(jù)項(xiàng)。
from?torch.utils.data?import?Dataset import?torchclass?NumbersDataset(Dataset):def?__init__(self,?low,?high):self.samples?=?list(range(low,?high))def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):n?=?self.samples[idx]successors?=?torch.arange(4).float()?+?n?+?1noisy?=?torch.randn(4)?+?successorsreturn?n,?successors,?noisyif?__name__?==?'__main__':from?torch.utils.data?import?DataLoaderdataset?=?NumbersDataset(100,?120)dataloader?=?DataLoader(dataset,?batch_size=10,?shuffle=True)print(next(iter(dataloader)))請注意,我們沒有更改數(shù)據(jù)集的構(gòu)造函數(shù),而是修改了__getitem__函數(shù)。對于PyTorch數(shù)據(jù)集來說,比較好的做法是,因?yàn)樵摂?shù)據(jù)集將隨著樣本越來越多而進(jìn)行縮放,因此我們不想在Dataset對象運(yùn)行時(shí),在內(nèi)存中存儲(chǔ)太多張量類型的數(shù)據(jù)。取而代之的是,當(dāng)我們遍歷樣本列表時(shí),我們將希望它是張量類型,以犧牲一些速度來節(jié)省內(nèi)存。在以下各節(jié)中,我將解釋它的用處。
觀察上面的輸出,盡管我們新的__getitem__函數(shù)返回了一個(gè)巨大的字符串和張量元組,但是DataLoader能夠識別數(shù)據(jù)并進(jìn)行相應(yīng)的堆疊。字符串化后的數(shù)字形成元組,其大小與創(chuàng)建DataLoader時(shí)配置的batch大小的相同。對于兩個(gè)張量,DataLoader將它們垂直堆疊成一個(gè)大小為10x4的張量。這是因?yàn)槲覀儗atch大小配置為10,并且在__getitem__函數(shù)返回兩個(gè)大小為4的張量。
通常來說,DataLoader嘗試將一批一維張量堆疊為二維張量,將一批二維張量堆疊為三維張量,依此類推。在這一點(diǎn)上,我懇請您注意到這對其他機(jī)器學(xué)習(xí)庫中的傳統(tǒng)數(shù)據(jù)處理產(chǎn)生了翻天覆地的影響,以及這個(gè)做法是多么優(yōu)雅。太不可思議了!如果您不同意我的觀點(diǎn),那么至少您現(xiàn)在知道有這樣的一種方法。
完成TES數(shù)據(jù)集的代碼
讓我們回到TES數(shù)據(jù)集。似乎初始化函數(shù)的代碼有點(diǎn)不優(yōu)雅(至少對于我而言,確實(shí)應(yīng)該有一種使代碼看起來更好的方法。請記住我說過的,PyTorch API是像python的(Pythonic)嗎?數(shù)據(jù)集中的工具函數(shù),甚至對內(nèi)部函數(shù)進(jìn)行初始化。為清理TES數(shù)據(jù)集的代碼,我們將更新TESNamesDataset的代碼來實(shí)現(xiàn)以下目的:
更新構(gòu)造函數(shù)以包含字符集
創(chuàng)建一個(gè)內(nèi)部函數(shù)來初始化數(shù)據(jù)集
創(chuàng)建一個(gè)將標(biāo)量轉(zhuǎn)換為獨(dú)熱(one-hot)張量的工具函數(shù)
創(chuàng)建一個(gè)工具函數(shù),該函數(shù)將樣本數(shù)據(jù)轉(zhuǎn)換為種族,性別和名稱的三個(gè)獨(dú)熱(one-hot)張量的集合。
為了使工具函數(shù)正常工作,我們將借助scikit-learn庫對數(shù)值(即種族,性別和名稱數(shù)據(jù))進(jìn)行編碼。具體來說,我們將需要LabelEncoder類。我們對代碼進(jìn)行大量的更新,我將在接下來的幾小節(jié)中解釋這些修改的代碼。
import?os from?sklearn.preprocessing?import?LabelEncoder from?torch.utils.data?import?Dataset import?torchclass?TESNamesDataset(Dataset):def?__init__(self,?data_root,?charset):self.data_root?=?data_rootself.charset?=?charsetself.samples?=?[]self.race_codec?=?LabelEncoder()self.gender_codec?=?LabelEncoder()self.char_codec?=?LabelEncoder()self._init_dataset()def?__len__(self):return?len(self.samples)def?__getitem__(self,?idx):race,?gender,?name?=?self.samples[idx]return?self.one_hot_sample(race,?gender,?name)def?_init_dataset(self):races?=?set()genders?=?set()for?race?in?os.listdir(self.data_root):race_folder?=?os.path.join(self.data_root,?race)races.add(race)for?gender?in?os.listdir(race_folder):gender_filepath?=?os.path.join(race_folder,?gender)genders.add(gender)with?open(gender_filepath,?'r')?as?gender_file:for?name?in?gender_file.read().splitlines():self.samples.append((race,?gender,?name))self.race_codec.fit(list(races))self.gender_codec.fit(list(genders))self.char_codec.fit(list(self.charset))def?to_one_hot(self,?codec,?values):value_idxs?=?codec.transform(values)return?torch.eye(len(codec.classes_))[value_idxs]def?one_hot_sample(self,?race,?gender,?name):t_race?=?self.to_one_hot(self.race_codec,?[race])t_gender?=?self.to_one_hot(self.gender_codec,?[gender])t_name?=?self.to_one_hot(self.char_codec,?list(name))return?t_race,?t_gender,?t_nameif?__name__?==?'__main__':import?stringdata_root?=?'/home/syafiq/Data/tes-names/'charset?=?string.ascii_letters?+?"-'?"dataset?=?TESNamesDataset(data_root,?charset)print(len(dataset))print(dataset[420])修改的構(gòu)造函數(shù)初始化
構(gòu)造函數(shù)這里有很多變化,所以讓我們一點(diǎn)一點(diǎn)地來解釋它。您可能已經(jīng)注意到構(gòu)造函數(shù)中沒有任何文件處理邏輯。我們已將此邏輯移至_init_dataset函數(shù)中,并清理了構(gòu)造函數(shù)。此外,我們添加了一些編碼器,來將原始字符串轉(zhuǎn)換為整數(shù)并返回。samples列表也是一個(gè)空列表,我們將在_init_dataset函數(shù)中填充該列表。構(gòu)造函數(shù)還接受一個(gè)新的參數(shù)charset。顧名思義,它只是一個(gè)字符串,可以將char_codec轉(zhuǎn)換為整數(shù)。
已增強(qiáng)了文件處理功能,該功能可以在我們遍歷文件夾時(shí)捕獲種族和性別的唯一標(biāo)簽。如果您沒有結(jié)構(gòu)良好的數(shù)據(jù)集,這將很有用;例如,如果Argonians擁有一個(gè)與性別無關(guān)的名稱,我們將擁有一個(gè)名為“Unknown”的文件,并將其放入性別集合中,而不管其他種族是否存在“Unknown”性別。所有名稱存儲(chǔ)完畢后,我們將在由種族,性別和名稱構(gòu)成數(shù)據(jù)集來初始化編碼器。
工具函數(shù)
我們添加了兩個(gè)工具函數(shù):to_one_hot和one_hot_sample。to_one_hot使用數(shù)據(jù)集的內(nèi)部編碼器將數(shù)值列表轉(zhuǎn)換為整數(shù)列表,然后再調(diào)用看似不適當(dāng)?shù)膖orch.eye函數(shù)。實(shí)際上,這是一種巧妙的技巧,可以將整數(shù)列表快速轉(zhuǎn)換為一個(gè)向量。torch.eye函數(shù)創(chuàng)建一個(gè)任意大小的單位矩陣,其對角線上的值為1。如果對矩陣行進(jìn)行索引,則將在該索引處獲得值為1的行向量,這是獨(dú)熱向量的定義!
因?yàn)槲覀冃枰獙⑷齻€(gè)數(shù)據(jù)轉(zhuǎn)換為張量,所以我們將在對應(yīng)數(shù)據(jù)的每個(gè)編碼器上調(diào)用to_one_hot函數(shù)。one_hot_sample將單個(gè)樣本數(shù)據(jù)轉(zhuǎn)換為張量元組。種族和性別被轉(zhuǎn)換為二維張量,這實(shí)際上是擴(kuò)展的行向量。該向量也被轉(zhuǎn)換為二維張量,但該二維向量包含該名稱的每個(gè)字符每個(gè)獨(dú)熱向量。
__getitem__調(diào)用
最后,__getitem__函數(shù)的代碼已更新為僅在one_hot_sample給定種族,性別和名稱的情況下調(diào)用該函數(shù)。注意,我們不需要在samples列表中預(yù)先準(zhǔn)備張量,而是僅在調(diào)用__getitem__函數(shù)(即DataLoader加載數(shù)據(jù)流時(shí))時(shí)形成張量。當(dāng)您在訓(xùn)練期間有成千上萬的樣本要加載時(shí),這使數(shù)據(jù)集具有很好的可伸縮性。
您可以想象如何在計(jì)算機(jī)視覺訓(xùn)練場景中使用該數(shù)據(jù)集。數(shù)據(jù)集將具有文件名列表和圖像目錄的路徑,從而讓__getitem__函數(shù)僅讀取圖像文件并將它們及時(shí)轉(zhuǎn)換為張量來進(jìn)行訓(xùn)練。通過提供適當(dāng)數(shù)量的工作線程,DataLoader可以并行處理多個(gè)圖像文件,可以使其運(yùn)行得更快。PyTorch數(shù)據(jù)加載教程(https://pytorch.org/tutorials/beginner/data_loading_tutorial.html)有更詳細(xì)的圖像數(shù)據(jù)集,加載器,和互補(bǔ)數(shù)據(jù)集。這些都是由torchvision庫進(jìn)行封裝的(它經(jīng)常隨著PyTorch一起安裝)。torchvision用于計(jì)算機(jī)視覺,使得圖像處理管道(例如增白,歸一化,隨機(jī)移位等)很容易構(gòu)建。
回到原文。數(shù)據(jù)集已經(jīng)構(gòu)建好了,看來我們已準(zhǔn)備好使用它進(jìn)行訓(xùn)練……
……但我們還沒有
如果我們嘗試使用DataLoader來加載batch大小大于1的數(shù)據(jù),則會(huì)遇到錯(cuò)誤:
您可能已經(jīng)看到過這種情況,但現(xiàn)實(shí)是,文本數(shù)據(jù)的不同樣本之間很少有相同的長度。結(jié)果,DataLoader嘗試批量處理多個(gè)不同長度的名稱張量,這在張量格式中是不可能的,因?yàn)樵贜umPy數(shù)組中也是如此。為了說明此問題,請考慮以下情況:當(dāng)我們將“ John”和“ Steven”之類的名稱堆疊在一起形成一個(gè)單一的獨(dú)熱矩陣時(shí)。'John'轉(zhuǎn)換為大小4xC的二維張量,'Steven'轉(zhuǎn)換為大小6xC二維張量,其中C是字符集的長度。DataLoader嘗試將這些名稱堆疊為大小2x?xC三維張量(DataLoader認(rèn)為堆積大小為1x4xC和1x6xC)。由于第二維不匹配,DataLoader拋出錯(cuò)誤,導(dǎo)致它無法繼續(xù)運(yùn)行。
可能的解決方案
為了解決這個(gè)問題,這里有兩種方法,每種方法都各有利弊。
將批處理(batch)大小設(shè)置為1,這樣您就永遠(yuǎn)不會(huì)遇到錯(cuò)誤。如果批處理大小為1,則單個(gè)張量不會(huì)與(可能)不同長度的其他任何張量堆疊在一起。但是,這種方法在進(jìn)行訓(xùn)練時(shí)會(huì)受到影響,因?yàn)樯窠?jīng)網(wǎng)絡(luò)在單批次(batch)的梯度下降時(shí)收斂將非常慢。另一方面,當(dāng)批次大小不重要時(shí),這對于快速測試時(shí),數(shù)據(jù)加載或沙盒測試很有用。
通過使用空字符填充或截?cái)嗝Q來獲得固定的長度。截短長的名稱或用空字符來填充短的名稱可以使所有名稱格式正確,并具有相同的輸出張量大小,從而可以進(jìn)行批處理。不利的一面是,根據(jù)任務(wù)的不同,空字符可能是有害的,因?yàn)樗荒艽碓紨?shù)據(jù)。
由于本文的目的,我將選擇第二個(gè)方法,您只需對整體數(shù)據(jù)管道進(jìn)行很少的更改即可實(shí)現(xiàn)此目的。請注意,這也適用于任何長度不同的字符數(shù)據(jù)(盡管有多種填充數(shù)據(jù)的方法,請參見NumPy(https://docs.scipy.org/doc/numpy-1.15.0/reference/generated/numpy.pad.html)和PyTorch(https://pytorch.org/docs/stable/_modules/torch/nn/modules/padding.html)中的選項(xiàng)部分)。在我的例子中,我選擇用零來填充名稱,因此我更新了構(gòu)造函數(shù)和_init_dataset函數(shù):
...?def?__init__(self,?data_root,?charset,?length):self.data_root?=?data_rootself.charset?=?charset?+?'\0'self.length?=?length...with?open(gender_filepath,?'r')?as?gender_file:for?name?in?gender_file.read().splitlines():if?len(name)?<?self.length:name?+=?'\0'?*?(self.length?-?len(name))else:name?=?name[:self.length-1]?+?'\0'self.samples.append((race,?gender,?name))...首先,我在構(gòu)造函數(shù)引入一個(gè)新的參數(shù),該參數(shù)將所有傳入名稱字符固定為length值。我還將\0字符添加到字符集中,用于填充短的名稱。接下來,數(shù)據(jù)集初始化邏輯已更新。缺少長度的名稱僅用\0填充,直到滿足長度的要求為止。超過固定長度的名稱將被截?cái)?#xff0c;最后一個(gè)字符將被替換為\0。替換是可選的,這取決于具體的任務(wù)。
而且,如果您現(xiàn)在嘗試加載此數(shù)據(jù)集,您應(yīng)該獲得跟您當(dāng)初所期望的數(shù)據(jù):正確的批(batch)大小格式的張量。下圖顯示了批大小為2的張量,但請注意有三個(gè)張量:
堆疊種族張量,獨(dú)熱編碼形式表示該張量是十個(gè)種族中的某一個(gè)種族
堆疊性別張量,獨(dú)熱編碼形式表示數(shù)據(jù)集中存在兩種性別中的某一種性別
堆疊名稱張量,最后一個(gè)維度應(yīng)該是charset的長度,第二個(gè)維度是名稱長度(固定大小后),第一個(gè)維度是批(batch)大小。
數(shù)據(jù)拆分實(shí)用程序
所有這些功能都內(nèi)置在PyTorch中,真是太棒了。現(xiàn)在可能出現(xiàn)的問題是,如何制作驗(yàn)證甚至測試集,以及如何在不擾亂代碼庫并盡可能保持DRY的情況下執(zhí)行驗(yàn)證或測試。測試集的一種方法是為訓(xùn)練數(shù)據(jù)和測試數(shù)據(jù)提供不同的data_root,并在運(yùn)行時(shí)保留兩個(gè)數(shù)據(jù)集變量(另外還有兩個(gè)數(shù)據(jù)加載器),尤其是在訓(xùn)練后立即進(jìn)行測試的情況下。
如果您想從訓(xùn)練集中創(chuàng)建驗(yàn)證集,那么可以使用PyTorch數(shù)據(jù)實(shí)用程序中的random_split 函數(shù)輕松處理這一問題。random_split 函數(shù)接受一個(gè)數(shù)據(jù)集和一個(gè)劃分子集大小的列表,該函數(shù)隨機(jī)拆分?jǐn)?shù)據(jù),以生成更小的Dataset對象,這些對象可立即與DataLoader一起使用。這里有一個(gè)例子。
通過使用內(nèi)置函數(shù)輕松拆分自定義PyTorch數(shù)據(jù)集來創(chuàng)建驗(yàn)證集。
事實(shí)上,您可以在任意間隔進(jìn)行拆分,這對于折疊交叉驗(yàn)證集非常有用。我對這個(gè)方法唯一的不滿是你不能定義百分比分割,這很煩人。至少子數(shù)據(jù)集的大小從一開始就明確定義了。另外,請注意,每個(gè)數(shù)據(jù)集都需要單獨(dú)的DataLoader,這絕對比在循環(huán)中管理兩個(gè)隨機(jī)排序的數(shù)據(jù)集和索引更干凈。
結(jié)束語
希望本文能使您了解PyTorch中Dataset和DataLoader實(shí)用程序的功能。與干凈的Pythonic API結(jié)合使用,它可以使編碼變得更加輕松愉快,同時(shí)提供一種有效的數(shù)據(jù)處理方式。我認(rèn)為PyTorch開發(fā)的易用性根深蒂固于他們的開發(fā)理念,并且在我的工作中使用PyTorch之后,我從此不再回頭使用Keras和TensorFlow。我不得不說我確實(shí)錯(cuò)過了Keras模型隨附的進(jìn)度條和fit /predict API,但這是一個(gè)小小的挫折,因?yàn)樽钚碌膸ensorBoard接口的PyTorch帶回了熟悉的工作環(huán)境。盡管如此,目前,PyTorch是我將來的深度學(xué)習(xí)項(xiàng)目的首選。
我鼓勵(lì)以這種方式構(gòu)建自己的數(shù)據(jù)集,因?yàn)樗宋乙郧肮芾頂?shù)據(jù)時(shí)遇到的許多凌亂的編程習(xí)慣。在復(fù)雜情況下,Dataset 是一個(gè)救命稻草。我記得必須管理屬于一個(gè)樣本的數(shù)據(jù),但該數(shù)據(jù)必須來自三個(gè)不同的MATLAB矩陣文件,并且需要正確切片,規(guī)范化和轉(zhuǎn)置。如果沒有Dataset和DataLoader組合,我不知如何進(jìn)行管理,特別是因?yàn)閿?shù)據(jù)量巨大,而且沒有簡便的方法將所有數(shù)據(jù)組合成NumPy矩陣且不會(huì)導(dǎo)致計(jì)算機(jī)崩潰。
最后,查看PyTorch數(shù)據(jù)實(shí)用程序文檔頁面(https://pytorch.org/docs/stable/data.html) ,其中包含其他類別和功能,這是一個(gè)很小但有價(jià)值的實(shí)用程序庫。您可以在我的GitHub上找到TES數(shù)據(jù)集的代碼,在該代碼中,我創(chuàng)建了與數(shù)據(jù)集同步的PyTorch中的LSTM名稱預(yù)測變量(https://github.com/syaffers/tes-names-rnn)。讓我知道這篇文章是有用的還是不清楚的,以及您將來是否希望獲得更多此類內(nèi)容。
原文鏈接:https://towardsdatascience.com/building-efficient-custom-datasets-in-pytorch-2563b946fd9f
- End -
往期精彩回顧適合初學(xué)者入門人工智能的路線及資料下載機(jī)器學(xué)習(xí)及深度學(xué)習(xí)筆記等資料打印機(jī)器學(xué)習(xí)在線手冊深度學(xué)習(xí)筆記專輯《統(tǒng)計(jì)學(xué)習(xí)方法》的代碼復(fù)現(xiàn)專輯 AI基礎(chǔ)下載機(jī)器學(xué)習(xí)的數(shù)學(xué)基礎(chǔ)專輯獲取一折本站知識星球優(yōu)惠券,復(fù)制鏈接直接打開:https://t.zsxq.com/yFQV7am本站qq群1003271085。加入微信群請掃碼進(jìn)群:總結(jié)
以上是生活随笔為你收集整理的【深度学习】在PyTorch中构建高效的自定义数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 【深度学习】常见优化器的PyTorch实
- 下一篇: 【论文解读】让特征感受野更灵活,腾讯优图