timm 笔记:数据集
1 數(shù)據(jù)集channel數(shù)問題
1.1 torchvision的不足
????????ImageNet數(shù)據(jù)由3通道RGB圖像組成。因此,為了能夠在大多數(shù)庫中使用預(yù)先訓(xùn)練的權(quán)值,模型期望一個(gè)3通道的輸入圖像。
? ? ? ? 比如對于resnet34,如果我們使用1個(gè)channel的輸入的話:
import torch import torchvisionm = torchvision.models.resnet34(pretrained=True)x = torch.randn(1, 1, 224, 224)try: m(x).shape except Exception as e: print(e) ''' Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 224, 224] to have 3 channels, but got 1 channels instead '''????????是會報(bào)錯的
? ? ? ? 此時(shí)的一種方法是將1維的channel復(fù)制兩次,成為三維的channel
import torch import torchvisionm = torchvision.models.resnet34(pretrained=True)x = torch.randn(1, 1, 224, 224)x=torch.cat((x,x,x),1)# 新增了這一行try: print(m(x).shape) except Exception as e: print(e) #torch.Size([1, 1000])????????然而,如果維度比3多的話,可能就沒有辦法刪去某個(gè)維度,然后使用預(yù)訓(xùn)練模型。它們可以做的只是隨機(jī)初始化權(quán)重,自己訓(xùn)練。
1.2 timm的解決方法
輸入channel是1或者25都o(jì)k了
import timmm = timm.create_model('resnet34', pretrained=True, in_chans=1)x = torch.randn(1, 1, 224, 224)m(x).shape#torch.Size([1, 1000]) m = timm.create_model('resnet34', pretrained=True, in_chans=25)# 25-channel image x = torch.randn(1, 25, 224, 224)m(x).shape #torch.Size([1, 1000])2 數(shù)據(jù)集Dataset
timm數(shù)據(jù)庫中,有三種主要的數(shù)據(jù)集類:
2.1 ImageDataset
? 與torchvision.datasets.ImageFolder?類似,ImageDataset的作用是創(chuàng)建訓(xùn)練集和驗(yàn)證集
2.1.1 解析器 parser
? ? ? ? 通過使用create_parser函數(shù),我們可以自動設(shè)置解析器
? ? ? ? 解析器找到所有root路徑上的圖片和目標(biāo)
? ? ? ? root路徑結(jié)構(gòu)如下所示?
? ?解析器創(chuàng)建一個(gè)class_to_idx字典:
? ?
? ? ? ? 同時(shí)有一個(gè)叫samples的元組列表:
????????
? ? ? ? 解析器是可以下標(biāo)訪問的,?parser[index]將返回一個(gè)self.samples中標(biāo)簽是index的樣本(比如parser[0],會返回一個(gè)('root/dog/xxx.png', 0)
?2.1.2?__getitem__(index: int) → Tuple[Any, Any]
一旦解析器創(chuàng)建完畢,那我們可以用以下方式獲得圖片和標(biāo)簽
img, target = self.parser[index]然后將圖像識別成PIL.Image,然后轉(zhuǎn)換成RGB圖像,還是讀取成二進(jìn)制,這取決于load_bytes語句
? ? ? ? 如果圖片沒有target,那么我們將target設(shè)置為-1
2.1.3 使用場景
????????ImageDataset也可以作為torchvision.datasets.ImageFolder的一個(gè)代替
? ? ? ? 假設(shè)我們有imagenette2-320數(shù)據(jù)集,他的文件架構(gòu)如下所示
數(shù)據(jù)集來源:
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz?每一個(gè) n****都是一個(gè)文件夾,里面是屬于這個(gè)類的JPEG文件
創(chuàng)建 ImageDataset:
from timm.data.dataset import ImageDatasetdataset = ImageDataset('./imagenette2-320') dataset[0] #(<PIL.Image.Image image mode=RGB size=426x320 at 0x22E890BF5C8>, 0)dataset.parser
from timm.data.dataset import ImageDatasetdataset = ImageDataset('./imagenette2-320') dataset.parser #<timm.data.parsers.parser_image_folder.ParserImageFolder at 0x22e83cd0688>?class_to_idx
from timm.data.dataset import ImageDatasetdataset = ImageDataset('./imagenette2-320') dataset.parser.class_to_idx ''' {'n01440764': 0,'n02102040': 1,'n02979186': 2,'n03000684': 3,'n03028079': 4,'n03394916': 5,'n03417042': 6,'n03425413': 7,'n03445777': 8,'n03888257': 9} '''?paser的sample
from timm.data.dataset import ImageDatasetdataset = ImageDataset('./imagenette2-320') dataset.parser.samples[:5] ''' [('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00000293.JPEG', 0),('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00002138.JPEG', 0),('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00003014.JPEG', 0),('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00006697.JPEG', 0),('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00007197.JPEG', 0)] '''可視化一張數(shù)據(jù)的圖片
import matplotlib.pyplot as plt # plt 用于顯示圖片 import matplotlib.image as mpimg # mpimg 用于讀取圖片 import numpy as nplena = mpimg.imread(dataset.parser.samples[0][0]) # 讀取和代碼處于同一目錄下的 lena.png # 此時(shí) lena 就已經(jīng)是一個(gè) np.array 了,可以對它進(jìn)行任意處理 lena.shape #(512, 512, 3)plt.imshow(lena) # 顯示圖片 plt.axis('off') # 不顯示坐標(biāo)軸 plt.show()?
2.2??IterableImageDataset
????????和pytorch的?IterableDataset?類似,timm提供了?IterableImageDataset。
??和ImageDataset相似,IterableImageDataset首先創(chuàng)建一個(gè)解析器,他也基于根目錄創(chuàng)建一組樣本。
????????和ImageDataset相似,解析器也返回一組圖像,圖像的target也是圖像所在的文件夾名稱
? ?***但有一點(diǎn)需要注意,IterableImageDataset并沒有__getitem__方法,因此他不可以用下標(biāo)訪問。dataset[0]會報(bào)錯
2.2.1 __iter__
? ? ? ?從IterableImageDataset的解析器中得到圖片和對應(yīng)的標(biāo)簽
from timm.data import IterableImageDataset from timm.data.parsers.parser_image_folder import ParserImageFolder from timm.data.transforms_factory import create_transform root = './imagenette2-320/' parser = ParserImageFolder(root) iterable_dataset = IterableImageDataset(root=root, parser=parser) parser[0] # (<_io.BufferedReader name='./imagenette2-320/train\\n01440764\\ILSVRC2012_val_00000293.JPEG'>,0) next(iter(iterable_dataset)) # (<_io.BufferedReader name='./imagenette2-320/train\\n01440764\\ILSVRC2012_val_00000293.JPEG'>,0)2.3 AugmixDataset
????????augmix 是一種數(shù)據(jù)增強(qiáng)的方法
class AugmixDataset(dataset: ImageDataset, num_splits: int = 2)????????最后的返回結(jié)果是 original data 和num_splits-1 輪的增強(qiáng)數(shù)據(jù)(每一輪增強(qiáng)數(shù)據(jù)都是原始數(shù)據(jù)的基礎(chǔ)上獲得的)
2.3.1? ?__getitem__(index: int) -> Tuple[Any, Any]
2.3.2 使用方法
這個(gè)需要GPU,所以我在服務(wù)器上跑的
>>> from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader >>> >>> dataset = ImageDataset('./imagenette2-320/') >>> dataset = AugMixDataset(dataset, num_splits=2) >>> loader_train = create_loader( ... dataset, ... input_size=(3, 224, 224), ... batch_size=8, ... is_training=True, ... scale=[0.08, 1.], ... ratio=[0.75, 1.33], ... num_aug_splits=2 ... ) >>> next(iter(loader_train))[0].shapetorch.Size([16, 3, 224, 224])注意看這里,我們的batch_size是8,返回的是16維,因?yàn)閛riginal是8,這里augmix又是8維
3 DataLoader
timm的 Dataloader比`torch.utils.data.DataLoader`快,且略有不同
創(chuàng)建timm的dataloader的最基本的方法就是調(diào)用timm.data.loader中的create_loader。它需要一個(gè)dataset對象,一個(gè)input_size和一個(gè)batch_size
3.1 創(chuàng)建dataset
?創(chuàng)建 ImageDataset:
from timm.data.dataset import ImageDatasetdataset = ImageDataset('./imagenette2-320') dataset[0] #(<PIL.Image.Image image mode=RGB size=426x320 at 0x22E890BF5C8>, 0)3.2 創(chuàng)建DataLoader
from timm.data.loader import create_loadertry:# only works if gpu present on machinetrain_loader = create_loader(dataset, (3, 224, 224), 4) except:train_loader = create_loader(dataset, (3, 224, 224), 4, use_prefetcher=False)那么,這里為什么要用異常處理語句呢??
3.2.1 Prefetch loader
????????timm 有一個(gè)類PrefetchLoader。我們默認(rèn)用這個(gè)DataLoader來創(chuàng)建我們的DataLoader。但是它只工作在GPU上。
? ? ? ? 我本地的train_loader:
<torch.utils.data.dataloader.DataLoader at 0x22e834c3548>? ? ? ? 服務(wù)器(有GPU)的train_loader:
<timm.data.loader.PrefetchLoader object at 0x7f65acf9cef0>
?
總結(jié)
以上是生活随笔為你收集整理的timm 笔记:数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: TImm 笔记: 训练模型
- 下一篇: 错误处理:RuntimeError: I