深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强
@Author:Runsen
上次對(duì)xml文件進(jìn)行提取,使用到一個(gè)Albumentation模塊。Albumentation模塊是一個(gè)數(shù)據(jù)增強(qiáng)的工具,目標(biāo)檢測(cè)圖像預(yù)處理通過(guò)使用“albumentation”來(lái)應(yīng)用的,這是一個(gè)易于與PyTorch數(shù)據(jù)轉(zhuǎn)換集成的python庫(kù)。
Albumentation 是一種工具,可以在將(圖像/圖片)插入模型之前自定義 處理(彈性、網(wǎng)格、運(yùn)動(dòng)模糊、移位、縮放、旋轉(zhuǎn)、轉(zhuǎn)置、對(duì)比度、亮度等])到圖像/圖片。
對(duì)此,Albumentation 官方文檔:
- https://albumentations.ai/
為什么要看看這個(gè)東西?因?yàn)閷?Torchvision 代碼重構(gòu)為 Albumentation 的效果最好,運(yùn)行更快。
上圖是使用 Intel Xeon Platinum 8168 CPU 在 ImageNet中通過(guò) 2000 個(gè)驗(yàn)證集圖像的測(cè)試結(jié)果。每個(gè)單元格中的值表示在單個(gè)核心中處理的圖像數(shù)量。可以看到 Albumentation在許多轉(zhuǎn)換方面比所有其他庫(kù)至少高出 2 倍。
Albumentation Github 的官方 CPU 基準(zhǔn)測(cè)試https://github.com/albumentations-team/albumentations
下面,我導(dǎo)入了下面的模塊:
from PIL import Image import time import torch import torchvision from torch.utils.data import Dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2 import numpy as np為了演示的目的,我找了一張前幾天畢業(yè)回校拍的照片
原始 TorchVision 數(shù)據(jù)管道
創(chuàng)建一個(gè) Dataloader 來(lái)使用 PyTorch 和 Torchvision 處理圖像數(shù)據(jù)管道。
- 創(chuàng)建一個(gè)簡(jiǎn)單的 Pytorch 數(shù)據(jù)集類
- 調(diào)用圖像并進(jìn)行轉(zhuǎn)換
- 用 100 個(gè)循環(huán)測(cè)量整個(gè)處理時(shí)間
首先,從torch.utils.data獲取 Dataset抽象類,并創(chuàng)建一個(gè) TorchVision數(shù)據(jù)集類。然后我插入圖像并使用__getitem__方法進(jìn)行轉(zhuǎn)換。另外,我用來(lái)total_time = (time.time() - start_t測(cè)量需要多長(zhǎng)時(shí)間
class TorchvisionDataset(Dataset):def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]image = Image.open(file_path)start_t = time.time()if self.transform:image = self.transform(image)total_time = (time.time() - start_t)return image, label, total_time然后將圖像大小調(diào)整為 256x256(高度 * 重量)并隨機(jī)裁剪到 224x224 大小。然后以 50% 的概率應(yīng)用水平翻轉(zhuǎn)并將其轉(zhuǎn)換為張量。輸入文件路徑應(yīng)該是您的圖像所在的 Google Drive 的路徑。
torchvision_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(), ])torchvision_dataset = TorchvisionDataset(file_paths=["demo.jpg"],labels=[1],transform=torchvision_transform, )下面計(jì)算從 torchvision_dataset 中提取樣本圖像并對(duì)其進(jìn)行轉(zhuǎn)換所花費(fèi)的時(shí)間,然后運(yùn)行 ??100 次循環(huán)以檢查它所花費(fèi)的平均毫秒。
torchvision time/sample: 7.31137752532959 ms在torch中的GPU,原始 TorchVision 數(shù)據(jù)管道數(shù)據(jù)預(yù)處理的速度大約是0.0731137752532959 ms。最后輸出的圖像則為 224x224而且發(fā)生了翻轉(zhuǎn)!
Albumentation 數(shù)據(jù)管道
現(xiàn)在創(chuàng)建了一個(gè) Albumentations Dataset 類,具體的transform和原始 TorchVision 數(shù)據(jù)管道完全一樣。
from PIL import Image import time import torch import torchvision from torch.utils.data import Dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2 import numpy as npclass AlbumentationsDataset(Dataset):"""__init__ and __len__ functions are the same as in TorchvisionDataset"""def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]# Read an image with OpenCVimage = cv2.imread(file_path)# By default OpenCV uses BGR color space for color images,# so we need to convert the image to RGB color space.image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)start_t = time.time()if self.transform:augmented = self.transform(image=image)image = augmented['image']total_time = (time.time() - start_t)return image, label, total_timealbumentations_transform = albumentations.Compose([albumentations.Resize(256, 256),albumentations.RandomCrop(224, 224),albumentations.HorizontalFlip(), # Same with transforms.RandomHorizontalFlip()albumentations.pytorch.transforms.ToTensor() ]) albumentations_dataset = AlbumentationsDataset(file_paths=["demo.jpg"],labels=[1],transform=albumentations_transform, )total_time = 0 for i in range(100):sample, _, transform_time = albumentations_dataset[0]total_time += transform_timeprint("albumentations time/sample: {} ms".format(total_time*10))plt.figure(figsize=(10, 10)) plt.imshow(transforms.ToPILImage()(sample)) plt.show()具體輸出如下:
albumentations time/sample: 0.5056881904602051 ms在torch中的GPU,Albumentation 數(shù)據(jù)管道 數(shù)據(jù)管道數(shù)據(jù)預(yù)處理的速度大約是0.005056881904602051 ms。
因此,在真正的工業(yè)落地,基本需要將原始 TorchVision 數(shù)據(jù)管道改寫成Albumentation 數(shù)據(jù)管道,因?yàn)槁涞仨?xiàng)目的速度很重要。
Albumentation數(shù)據(jù)增強(qiáng)
最后,我將展示如何使用albumentations中OneOf函數(shù)進(jìn)行書(shū)增強(qiáng),我個(gè)人覺(jué)得這個(gè)函數(shù)在 Albumentation 中非常有用。
from PIL import Image import time import torch import torchvision from torch.utils.data import Dataset from torchvision import transforms import albumentations import albumentations.pytorch from matplotlib import pyplot as plt import cv2class AlbumentationsDataset(Dataset):"""__init__ and __len__ functions are the same as in TorchvisionDataset"""def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __len__(self):return len(self.file_paths)def __getitem__(self, idx):label = self.labels[idx]file_path = self.file_paths[idx]image = cv2.imread(file_path)image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)if self.transform:augmented = self.transform(image=image)image = augmented['image']return image, label# OneOf隨機(jī)采用括號(hào)內(nèi)列出的變換之一。 # 我們甚至可以將發(fā)生的概率放在函數(shù)本身中。例如,如果 ([…], p=0.5) 之一,它會(huì)以 50% 的機(jī)會(huì)跳過(guò)整個(gè)變換,并以 1/6 的機(jī)會(huì)隨機(jī)選擇三個(gè)變換之一。 albumentations_transform_oneof = albumentations.Compose([albumentations.Resize(256, 256),albumentations.RandomCrop(224, 224),albumentations.OneOf([albumentations.HorizontalFlip(p=1),albumentations.RandomRotate90(p=1),albumentations.VerticalFlip(p=1)], p=1),albumentations.OneOf([albumentations.MotionBlur(p=1),albumentations.OpticalDistortion(p=1), albumentations.GaussNoise(p=1)], p=1),albumentations.pytorch.ToTensor() ])albumentations_dataset = AlbumentationsDataset(file_paths=["demo.jpg"],labels=[1],transform=albumentations_transform_oneof, )num_samples = 5 fig, ax = plt.subplots(1, num_samples, figsize=(25, 5)) for i in range(num_samples):ax[i].imshow(transforms.ToPILImage()(albumentations_dataset[0][0]))ax[i].axis('off')plt.show()
上面的OneOf是在水平翻轉(zhuǎn)、旋轉(zhuǎn)、垂直翻轉(zhuǎn)中隨機(jī)選擇,在模糊、失真、噪聲中隨機(jī)選擇。所以在這種情況下,我們?cè)试S 3x3 = 9 種組合
總結(jié)
以上是生活随笔為你收集整理的深度学习和目标检测系列教程 9-300:TorchVision和Albumentation性能对比,如何使用Albumentation对图片数据做数据增强的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: 抵押车供求关系
- 下一篇: 大众途观消音器哒哒声音