Pytorch(3)-数据载入接口:Dataloader、datasets
pytorch數據載入
- 1.數據載入概況
- Dataloader 是啥
- 2.支持的三類數據集
- 2.1 torchvision.datasets.xxx
- 2.2 torchvision.datasets.ImageFolder
- 2.3 寫自己的數據類,讀入定制化數據
- 2.3.1 數據類的編寫
- map-style范式
- iterable-style 范式
- 2.3.2 DataLoader 導入數據類
1.數據載入概況
數據是機器學習算法的驅動力, Pytorch提供了方便的數據載入和處理接口. 數據載入流程為:
step1: 指定要使用的數據集dataset
step2: 使用Dataloader載入數據
dataloader實質是一個可迭代對象,不能使用next()訪問。但如果使用iter()封裝,返回一個迭代器,可以使用.next()操作。
Dataloader 是啥
來自官網document的描述:
Dataloader. Combines a dataset and a sampler, and provides an iterable over the given dataset. The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning. See torch.utils.data documentation page for more details.大概就是說:用來對數據集進行(小批次)迭代 載入的接口,所能夠載入的數據集要么支持map-style操作,要么支持 iterable-style操作。
(這兩種操作只有在編寫用戶數據類時才需要考慮,使用內置公開數據集和.ImageFolder不需要管這兩者是啥東西,開發者已經幫你寫好了)
2.支持的三類數據集
1.torchvision.datasets–內置了許多常見的公開數據集
2.torchvision.datasets.ImageFolder–用戶定制數據集1(只要自己的數據集滿足ImageFolder要求的格式,提供數據集所在的地址即可)
3.定制數據集–需要編寫自己的dataset 類
2.1 torchvision.datasets.xxx
一些常用的公開數據集合,可以在torchvision.datasets接口中找到。
例如–MNIST、Fashion-MNIST、KMNIST、EMNIST、FakeData、COCO、Captions、Detection、LSUN、ImageFolder、DatasetFolder、ImageNet、CIFAR、STL10、SVHN、PhotoTour、SBU、Flickr、VOC、Cityscapes、SBD等常用數據集合。
torchvision.datasets在使用一個新的數據集合前,需要保證本地擁有該數據集合(符合pytorch內部編碼格式)。最簡單額方式是第一次使用時,將download=True將默認將該數據集下載到指定的root 目錄中。
CIFAR10數據集使用的例子
transform = transforms.Compose( [transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2) 默認值:train=Truestep1 數據集選擇與圖片處理方式選擇
trainset = torchvision.datasets.CIFAR10(root=’./data’, train=True,download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root=’./data’, train=False,download=True, transform=transform)
參數解釋:
1.root=’./data’
數據集的保存目錄,各種數據集有自己的文件格式,其中MNIST是以training.pt和test.pt的保存圖像數據信息(具體看一下文件應該怎么存,讀入之后的列表和迭代器各是什么內容)
2.train =True
處理MNIST時從training.pt讀取訓練數據,=False 從test.pt讀取測試數據。仔細觀察,上面兩句話只有在train這個選項處不同.
3.download =True
會從網上下載對應的數據集文件,MNIST對應.pt文件,如果存在 .pt 文件,這個參數可以設置為False
4.transform
設置一組對圖像進行處理的操作,這一組操作由Compose組成,這一組compose 的順序還很重要按如下順序編寫:
transforms.Resize()
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
step2 數據載入接口
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
參數解釋
1.將剛剛生成的trainset列表傳入 torch.utils.data.DataLoader()
2.batch_size=4 設定圖像數據批次大小
3.shuffle=True 每一個epoch過程中會打亂數據順序,重新隨機選擇
4.導入數據時的線程數目,默認為0,主線程導入數據
2.2 torchvision.datasets.ImageFolder
當數據集超出1中所提供數據集的范圍時,Pytorch還提供了ImageFolder數據集導入方式。只要將數據按照一定的要求存放,就能如方式1一樣方便取用。
數據集合格式要求:同類別的圖像放在一個文件夾下,用類別名稱/標號來命名文件夾。要自己手工設計訓練集合、測試集合。
x=torch.datasets.ImageFolder(root="圖像集合中文件夾路徑”)
x是一個ImageFolder格式的數據:
其中重要主要成員為:
class_to_idx ={dict} 是字典數據,以“文件夾名字:分配類別序號”作為鍵值的字典
classes ={list} 包含所有文件夾名字的一個序列
imgs={list} 列表元素為–(圖像路徑,對應文件夾名)
使用torch.utils.data.DataLoader載入數據:
trainloader = torch.utils.data.DataLoader(x, batch_size=4, shuffle=True, num_workers=4)
參考網址:
https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
2.3 寫自己的數據類,讀入定制化數據
當用戶數據個格式不能用以上兩種方式讀取時,可以嘗試寫自己的數據類
所有的datasets都是torch.utils.data.Dataset的子類,方法1中使用的是torchvision.datasets.數據集,方法 2中使用的是torchvision.datasets.ImageFolder。當我們在編寫自己的數據類時,也需要繼承Dataset類。
2.3.1 數據類的編寫
在介紹Dataloader 使提到過,其載入的數據類需要滿足兩者操作中的一個(map-style操作/iterable-style操作)
map-style范式
Map-style 操作范式數據類的核心:實現了 getitem() 和 len()方法,通過data[index]獲取數據樣本和相應的標簽。
猜測:DataLoader 在導入minibatch數據時,隨機采樣一批index(通過len確認index 的采樣范圍), 然后在經過getitem獲取相應的數據
class MyDataset:def __init__(self, gentor: object, batchSize: int, imgSize: int):# 從源數據中讀取數據列表,或者能操作數據的名稱列表def __len__(self):# 返回數據集樣本的數量return sample_map_numdef __getitem__(self, idx:int):# 通過idx獲取數據datadata = get(idx) // get 依據不同的數據集定制// 進行一些tansform操作在返回return data官方實踐demo:https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
iterable-style 范式
Iterable-style 操作范式數據類 是 IterableDataset的子類,實現了__iter__()方法。當隨機讀取非常耗時/無法實現時。(數據流,實時記錄的數據)
有機會實踐一下
2.3.2 DataLoader 導入數據類
編寫好了自己的數據類之后,同其他兩種數據類一樣使用DataLoader導入數據即可。
train_set = MyDataset()data = train_set[0] # idx 讀取某一個數據trainloader = DataLoader(train_set, batch_size=64, shuffle=True) # 封裝成dataloader的形式print(len(trainloader))for _, data in enumerate(trainloader):....下面提供一些可供參考的博文:
https://www.jianshu.com/p/220357ca3342
https://www.cnblogs.com/devilmaycry812839668/p/10122148.html
https://ptorch.com/news/215.html
總結
以上是生活随笔為你收集整理的Pytorch(3)-数据载入接口:Dataloader、datasets的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: FM,FFM及其实现
- 下一篇: CSDN写博客(字体颜色、大小)