动手学CV-目标检测入门教程2:VOC数据集
3.2 目標(biāo)檢測(cè)數(shù)據(jù)集VOC
本文來自開源組織 DataWhale 🐳 CV小組創(chuàng)作的目標(biāo)檢測(cè)入門教程。
對(duì)應(yīng)開源項(xiàng)目 《動(dòng)手學(xué)CV-Pytorch》 的第3章的內(nèi)容,教程中涉及的代碼也可以在項(xiàng)目中找到,后續(xù)會(huì)持續(xù)更新更多的優(yōu)質(zhì)內(nèi)容,歡迎??。
如果使用我們教程的內(nèi)容或圖片,請(qǐng)?jiān)谖恼滦涯课恢米⒚魑覀兊膅ithub主頁鏈接:https://github.com/datawhalechina/dive-into-cv-pytorch
3.2.1 VOC數(shù)據(jù)集簡介
VOC數(shù)據(jù)集是目標(biāo)檢測(cè)領(lǐng)域最常用的標(biāo)準(zhǔn)數(shù)據(jù)集之一,幾乎所有檢測(cè)方向的論文,如faster_rcnn、yolo、SSD等都會(huì)給出其在VOC數(shù)據(jù)集上訓(xùn)練并評(píng)測(cè)的效果。因此我們我們的教程也基于VOC來開展實(shí)驗(yàn),具體地,我們使用VOC2007和VOC2012這兩個(gè)最流行的版本作為訓(xùn)練和測(cè)試的數(shù)據(jù)。
數(shù)據(jù)集類別
VOC數(shù)據(jù)集在類別上可以分為4大類,20小類,其類別信息如圖3-5所示。
圖3-5 VOC數(shù)據(jù)集目標(biāo)類別劃分數(shù)據(jù)集量級(jí)
VOC數(shù)量集圖像和目標(biāo)數(shù)量的基本信息如下圖3-6所示:
圖3-6 VOC數(shù)據(jù)集數(shù)據(jù)量級(jí)對(duì)比其中,Images表示圖片數(shù)量,Objects表示目標(biāo)數(shù)量
數(shù)據(jù)集下載
VOC官網(wǎng)經(jīng)常上不去,為確保后續(xù)實(shí)驗(yàn)準(zhǔn)確且順利的進(jìn)行,大家可以點(diǎn)擊這里的百度云鏈接進(jìn)行下載:
🐳 VOC百度云下載鏈接 解壓碼(7aek)
下載后放到dataset目錄下解壓即可
下面是通過官網(wǎng)下載的步驟:
進(jìn)入VOC官網(wǎng)鏈接:http://host.robots.ox.ac.uk/pascal/VOC/
在圖3-7所示區(qū)域找到歷年VOC挑戰(zhàn)賽鏈接,比如選擇VOC2012.
數(shù)據(jù)集說明
將下載得到的壓縮包解壓,可以得到如圖3-9所示的一系列文件夾,由于VOC數(shù)據(jù)集不僅被拿來做目標(biāo)檢測(cè),也可以拿來做分割等任務(wù),因此除了目標(biāo)檢測(cè)所需的文件之外,還包含分割任務(wù)所需的文件,比如SegmentationClass,SegmentationObject,這里,我們主要對(duì)目標(biāo)檢測(cè)任務(wù)涉及到的文件進(jìn)行介紹。
圖3-9 VOC壓縮包解壓所得文件夾示例1.JPEGImages
這個(gè)文件夾中存放所有的圖片,包括訓(xùn)練驗(yàn)證測(cè)試用到的所有圖片。
2.ImageSets
這個(gè)文件夾中包含三個(gè)子文件夾,Layout、Main、Segmentation
-
Layout文件夾中存放的是train,valid,test和train+valid數(shù)據(jù)集的文件名
-
Segmentation文件夾中存放的是分割所用train,valid,test和train+valid數(shù)據(jù)集的文件名
-
Main文件夾中存放的是各個(gè)類別所在圖片的文件名,比如cow_val,表示valid數(shù)據(jù)集中,包含有cow類別目標(biāo)的圖片名稱。
3.Annotations
Annotation文件夾中存放著每張圖片相關(guān)的標(biāo)注信息,以xml格式的文件存儲(chǔ),可以通過記事本或者瀏覽器打開,我們以000001.jpg這張圖片為例說明標(biāo)注文件中各個(gè)屬性的含義,見圖3-10。
圖3-10 VOC數(shù)據(jù)集000001.jpg圖片(左)和標(biāo)注信息(右)猛一看去,內(nèi)容又多又復(fù)雜,其實(shí)仔細(xì)研究一下,只有紅框區(qū)域內(nèi)的內(nèi)容是我們真正需要關(guān)注的。
filename:圖片名稱
size:圖片寬高,
depth表示圖片通道數(shù)
object:表示目標(biāo),包含下面兩部分內(nèi)容。
-
首先是目標(biāo)類別name為dog。pose表示目標(biāo)姿勢(shì)為left,truncated表示是否是一個(gè)被截?cái)嗟哪繕?biāo),1表示是,0表示不是,在這個(gè)例子中,只露出狗頭部分,所以truncated為1。difficult為0表示此目標(biāo)不是一個(gè)難以識(shí)別的目標(biāo)。
-
然后就是目標(biāo)的bbox信息,可以看到,這里是以[xmin,ymin,xmax,ymax]格式進(jìn)行標(biāo)注的,分別表示dog目標(biāo)的左上角和右下角坐標(biāo)。
3.2.2 VOC數(shù)據(jù)集的dataloader的構(gòu)建
1. 數(shù)據(jù)集準(zhǔn)備
根據(jù)上面的介紹可以看出,VOC數(shù)據(jù)集的存儲(chǔ)格式還是比較復(fù)雜的,為了后面訓(xùn)練中的讀取代碼更加簡潔,這里我們準(zhǔn)備了一個(gè)預(yù)處理腳本create_data_lists.py。
該腳本的作用是進(jìn)行一系列的數(shù)據(jù)準(zhǔn)備工作,主要是提前將記錄標(biāo)注信息的xml文件(Annotations)進(jìn)行解析,并將信息整理到j(luò)son文件之中,這樣在運(yùn)行訓(xùn)練腳本時(shí),只需簡單的從json文件中讀取已經(jīng)按想要的格式存儲(chǔ)好的標(biāo)簽信息即可。
注: 這樣的預(yù)處理并不是必須的,和算法或數(shù)據(jù)集本身均無關(guān)系,只是取決于開發(fā)者的代碼習(xí)慣,不同檢測(cè)框架的處理方法也是不一致的。
可以看到,create_data_lists.py腳本僅有幾行代碼,其內(nèi)部調(diào)用了utils.py中的create_data_lists方法:
"""pythoncreate_data_lists """ from utils import create_data_listsif __name__ == '__main__':# voc07_path,voc12_path為我們訓(xùn)練測(cè)試所需要用到的數(shù)據(jù)集,output_folder為我們生成構(gòu)建dataloader所需文件的路徑# 參數(shù)中涉及的路徑以個(gè)人實(shí)際路徑為準(zhǔn),建議將數(shù)據(jù)集放到dataset目錄下,和教程保持一致create_data_lists(voc07_path='../../../dataset/VOCdevkit/VOC2007',voc12_path='../../../dataset/VOCdevkit/VOC2012',output_folder='../../../dataset/VOCdevkit')設(shè)置好對(duì)應(yīng)路徑后,我們運(yùn)行數(shù)據(jù)集準(zhǔn)備腳本:
tiny_detector_demo$ python create_data_lists.py
很快啊!dataset/VOCdevkit目錄下就生成了若干json文件,這些文件會(huì)在后面訓(xùn)練中真正被用到。
不妨手動(dòng)打開這些json文件,看下都記錄了哪些信息。
下面來介紹一下parse_annotation函數(shù)內(nèi)部都做了什么,json中又記錄了哪些信息。這部分作為選學(xué),不感興趣可以跳過,只要你已經(jīng)明確了json中記錄的信息的含義。
代碼閱讀可以參照注釋,建議配圖3-11一起食用:
"""pythonxml文件解析 """import json import os import torch import random import xml.etree.ElementTree as ET #解析xml文件所用工具 import torchvision.transforms.functional as FT#GPU設(shè)置 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Label map #voc_labels為VOC數(shù)據(jù)集中20類目標(biāo)的類別名稱 voc_labels = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable','dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')#創(chuàng)建label_map字典,用于存儲(chǔ)類別和類別索引之間的映射關(guān)系。比如:{1:'aeroplane', 2:'bicycle',......} label_map = {k: v + 1 for v, k in enumerate(voc_labels)} #VOC數(shù)據(jù)集默認(rèn)不含有20類目標(biāo)中的其中一類的圖片的類別為background,類別索引設(shè)置為0 label_map['background'] = 0#將映射關(guān)系倒過來,{類別名稱:類別索引} rev_label_map = {v: k for k, v in label_map.items()} # Inverse mapping#解析xml文件,最終返回這張圖片中所有目標(biāo)的標(biāo)注框及其類別信息,以及這個(gè)目標(biāo)是否是一個(gè)difficult目標(biāo) def parse_annotation(annotation_path):#解析xmltree = ET.parse(annotation_path)root = tree.getroot()boxes = list() #存儲(chǔ)bboxlabels = list() #存儲(chǔ)bbox對(duì)應(yīng)的labeldifficulties = list() #存儲(chǔ)bbox對(duì)應(yīng)的difficult信息#遍歷xml文件中所有的object,前面說了,有多少個(gè)object就有多少個(gè)目標(biāo)for object in root.iter('object'):#提取每個(gè)object的difficult、label、bbox信息difficult = int(object.find('difficult').text == '1')label = object.find('name').text.lower().strip()if label not in label_map:continuebbox = object.find('bndbox')xmin = int(bbox.find('xmin').text) - 1ymin = int(bbox.find('ymin').text) - 1xmax = int(bbox.find('xmax').text) - 1ymax = int(bbox.find('ymax').text) - 1#存儲(chǔ)boxes.append([xmin, ymin, xmax, ymax])labels.append(label_map[label])difficulties.append(difficult)#返回包含圖片標(biāo)注信息的字典return {'boxes': boxes, 'labels': labels, 'difficulties': difficulties}看了上面的代碼如果還不太明白,試試結(jié)合這張圖理解下:
圖3-11 xml解析流程圖接下來看一下create_data_lists函數(shù)在做什么,建議配圖3-12一起食用:
"""python分別讀取train和valid的圖片和xml信息,創(chuàng)建用于訓(xùn)練和測(cè)試的json文件 """ def create_data_lists(voc07_path, voc12_path, output_folder):"""Create lists of images, the bounding boxes and labels of the objects in these images, and save these to file.:param voc07_path: path to the 'VOC2007' folder:param voc12_path: path to the 'VOC2012' folder:param output_folder: folder where the JSONs must be saved"""#獲取voc2007和voc2012數(shù)據(jù)集的絕對(duì)路徑voc07_path = os.path.abspath(voc07_path)voc12_path = os.path.abspath(voc12_path)train_images = list()train_objects = list()n_objects = 0# Training datafor path in [voc07_path, voc12_path]:# Find IDs of images in training data#獲取訓(xùn)練所用的train和val數(shù)據(jù)的圖片idwith open(os.path.join(path, 'ImageSets/Main/trainval.txt')) as f:ids = f.read().splitlines()#根據(jù)圖片id,解析圖片的xml文件,獲取標(biāo)注信息for id in ids:# Parse annotation's XML fileobjects = parse_annotation(os.path.join(path, 'Annotations', id + '.xml'))if len(objects['boxes']) == 0: #如果沒有目標(biāo)則跳過continuen_objects += len(objects) #統(tǒng)計(jì)目標(biāo)總數(shù)train_objects.append(objects) #存儲(chǔ)每張圖片的標(biāo)注信息到列表train_objectstrain_images.append(os.path.join(path, 'JPEGImages', id + '.jpg')) #存儲(chǔ)每張圖片的路徑到列表train_images,用于讀取圖片assert len(train_objects) == len(train_images) #檢查圖片數(shù)量和標(biāo)注信息量是否相等,相等才繼續(xù)執(zhí)行程序# Save to file#將訓(xùn)練數(shù)據(jù)的圖片路徑,標(biāo)注信息,類別映射信息,分別保存為json文件with open(os.path.join(output_folder, 'TRAIN_images.json'), 'w') as j:json.dump(train_images, j)with open(os.path.join(output_folder, 'TRAIN_objects.json'), 'w') as j:json.dump(train_objects, j)with open(os.path.join(output_folder, 'label_map.json'), 'w') as j:json.dump(label_map, j) # save label map tooprint('\nThere are %d training images containing a total of %d objects. Files have been saved to %s.' % (len(train_images), n_objects, os.path.abspath(output_folder)))#與Train data一樣,目的是將測(cè)試數(shù)據(jù)的圖片路徑,標(biāo)注信息,類別映射信息,分別保存為json文件,參考上面的注釋理解# Test datatest_images = list()test_objects = list()n_objects = 0# Find IDs of images in the test datawith open(os.path.join(voc07_path, 'ImageSets/Main/test.txt')) as f:ids = f.read().splitlines()for id in ids:# Parse annotation's XML fileobjects = parse_annotation(os.path.join(voc07_path, 'Annotations', id + '.xml'))if len(objects) == 0:continuetest_objects.append(objects)n_objects += len(objects)test_images.append(os.path.join(voc07_path, 'JPEGImages', id + '.jpg'))assert len(test_objects) == len(test_images)# Save to filewith open(os.path.join(output_folder, 'TEST_images.json'), 'w') as j:json.dump(test_images, j)with open(os.path.join(output_folder, 'TEST_objects.json'), 'w') as j:json.dump(test_objects, j)print('\nThere are %d test images containing a total of %d objects. Files have been saved to %s.' % (len(test_images), n_objects, os.path.abspath(output_folder)))同樣,建議配圖食用:
圖3-12 數(shù)據(jù)準(zhǔn)備流程圖(以train_dataset為例)到這里,我們的訓(xùn)練數(shù)據(jù)就準(zhǔn)備好了,接下來開始一步步構(gòu)建訓(xùn)練所需的dataloader吧!
2.構(gòu)建dataloader
在這里,我們假設(shè)你對(duì)Pytorch的 Dataset 和 DataLoader 兩個(gè)概念有最基本的了解。
如果沒有,也不必?fù)?dān)心,你可以先閱讀一下第2-1節(jié)數(shù)據(jù)讀取與數(shù)據(jù)擴(kuò)增,進(jìn)行簡單的了解。
下面開始介紹構(gòu)建dataloader的相關(guān)代碼:
1.首先了解一下訓(xùn)練的時(shí)候在哪里定義了dataloader以及是如何定義的。
以下是train.py中的部分代碼段:
#train_dataset和train_loader的實(shí)例化train_dataset = PascalVOCDataset(data_folder,split='train',keep_difficult=keep_difficult)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=train_dataset.collate_fn, num_workers=workers,pin_memory=True) # note that we're passing the collate function here可以看到,首先需要實(shí)例化PascalVOCDataset類得到train_dataset,然后將train_dataset傳入torch.utils.data.DataLoader,進(jìn)而得到train_loader。
2.接下來看一下PascalVOCDataset是如何定義的。
代碼位于 datasets.py 腳本中,可以看到,PascalVOCDataset繼承了torch.utils.data.Dataset,然后重寫了__init__ , __getitem__, __len__ 和 collate_fn 四個(gè)方法,這也是我們?cè)跇?gòu)建自己的dataset的時(shí)候需要經(jīng)常做的工作,配合下面注釋理解代碼:
"""pythonPascalVOCDataset具體實(shí)現(xiàn)過程 """ import torch from torch.utils.data import Dataset import json import os from PIL import Image from utils import transformclass PascalVOCDataset(Dataset):"""A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches."""#初始化相關(guān)變量#讀取images和objects標(biāo)注信息def __init__(self, data_folder, split, keep_difficult=False):""":param data_folder: folder where data files are stored:param split: split, one of 'TRAIN' or 'TEST':param keep_difficult: keep or discard objects that are considered difficult to detect?"""self.split = split.upper() #保證輸入為純大寫字母,便于匹配{'TRAIN', 'TEST'}assert self.split in {'TRAIN', 'TEST'}self.data_folder = data_folderself.keep_difficult = keep_difficult# Read data fileswith open(os.path.join(data_folder, self.split + '_images.json'), 'r') as j:self.images = json.load(j)with open(os.path.join(data_folder, self.split + '_objects.json'), 'r') as j:self.objects = json.load(j)assert len(self.images) == len(self.objects)#循環(huán)讀取image及對(duì)應(yīng)objects#對(duì)讀取的image及objects進(jìn)行tranform操作(數(shù)據(jù)增廣)#返回PIL格式圖像,標(biāo)注框,標(biāo)注框?qū)?yīng)的類別索引,對(duì)應(yīng)的difficult標(biāo)志(True or False)def __getitem__(self, i):# Read image#*需要注意,在pytorch中,圖像的讀取要使用Image.open()讀取成PIL格式,不能使用opencv#*由于Image.open()讀取的圖片是四通道的(RGBA),因此需要.convert('RGB')轉(zhuǎn)換為RGB通道image = Image.open(self.images[i], mode='r')image = image.convert('RGB')# Read objects in this image (bounding boxes, labels, difficulties)objects = self.objects[i]boxes = torch.FloatTensor(objects['boxes']) # (n_objects, 4)labels = torch.LongTensor(objects['labels']) # (n_objects)difficulties = torch.ByteTensor(objects['difficulties']) # (n_objects)# Discard difficult objects, if desired#如果self.keep_difficult為False,即不保留difficult標(biāo)志為True的目標(biāo)#那么這里將對(duì)應(yīng)的目標(biāo)刪去if not self.keep_difficult:boxes = boxes[1 - difficulties]labels = labels[1 - difficulties]difficulties = difficulties[1 - difficulties]# Apply transformations#對(duì)讀取的圖片應(yīng)用transformimage, boxes, labels, difficulties = transform(image, boxes, labels, difficulties, split=self.split)return image, boxes, labels, difficulties#獲取圖片的總數(shù),用于計(jì)算batch數(shù)def __len__(self):return len(self.images)#我們知道,我們輸入到網(wǎng)絡(luò)中訓(xùn)練的數(shù)據(jù)通常是一個(gè)batch一起輸入,而通過__getitem__我們只讀取了一張圖片及其objects信息#如何將讀取的一張張圖片及其object信息整合成batch的形式呢?#collate_fn就是做這個(gè)事情,#對(duì)于一個(gè)batch的images,collate_fn通過torch.stack()將其整合成4維tensor,對(duì)應(yīng)的objects信息分別用一個(gè)list存儲(chǔ)def collate_fn(self, batch):"""Since each image may have a different number of objects, we need a collate function (to be passed to the DataLoader).This describes how to combine these tensors of different sizes. We use lists.Note: this need not be defined in this Class, can be standalone.:param batch: an iterable of N sets from __getitem__():return: a tensor of images, lists of varying-size tensors of bounding boxes, labels, and difficulties"""images = list()boxes = list()labels = list()difficulties = list()for b in batch:images.append(b[0])boxes.append(b[1])labels.append(b[2])difficulties.append(b[3])#(3,224,224) -> (N,3,224,224)images = torch.stack(images, dim=0)return images, boxes, labels, difficulties # tensor (N, 3, 224, 224), 3 lists of N tensors each3.關(guān)于數(shù)據(jù)增強(qiáng)
到這里為止,我們的dataset就算是構(gòu)建好了,已經(jīng)可以傳給torch.utils.data.DataLoader來獲得用于輸入網(wǎng)絡(luò)訓(xùn)練的數(shù)據(jù)了。
但是不急,構(gòu)建dataset中有個(gè)很重要的一步我們上面只是提及了一下,那就是transform操作(數(shù)據(jù)增強(qiáng))。
也就是這一行代碼
image, boxes, labels, difficulties = transform(image, boxes, labels, difficulties, split=self.split)這部分比較重要,但是涉及代碼稍多,對(duì)于基礎(chǔ)較薄弱的伙伴可以作為選學(xué)內(nèi)容,后面再認(rèn)真讀代碼。你只需知道,同分類網(wǎng)絡(luò)一樣,訓(xùn)練目標(biāo)檢測(cè)網(wǎng)絡(luò)同樣需要進(jìn)行數(shù)據(jù)增強(qiáng),這對(duì)提升網(wǎng)絡(luò)精度和泛化能力很有幫助。
需要注意的是,涉及位置變化的數(shù)據(jù)增強(qiáng)方法,同樣需要對(duì)目標(biāo)框進(jìn)行一致的處理,因此目標(biāo)檢測(cè)框架的數(shù)據(jù)處理這部分的代碼量通常都不小,且比較容易出bug。這里為了降低代碼的難度,我們只是使用了幾種比較簡單的數(shù)據(jù)增強(qiáng)。
transform 函數(shù)的具體代碼實(shí)現(xiàn)位于 utils.py 中,下面簡單進(jìn)行講解:
"""pythontransform操作是訓(xùn)練模型中一項(xiàng)非常重要的工作,其中不僅包含數(shù)據(jù)增強(qiáng)以提升模型性能的相關(guān)操作,也包含如數(shù)據(jù)類型轉(zhuǎn)換(PIL to Tensor)、歸一化(Normalize)這些必要操作。 """ import json import os import torch import random import xml.etree.ElementTree as ET import torchvision.transforms.functional as FT""" 可以看到,transform分為TRAIN和TEST兩種模式,以本實(shí)驗(yàn)為例:在TRAIN時(shí)進(jìn)行的transform有: 1.以隨機(jī)順序改變圖片亮度,對(duì)比度,飽和度和色相,每種都有50%的概率被執(zhí)行。photometric_distort 2.擴(kuò)大目標(biāo),expand 3.隨機(jī)裁剪圖片,random_crop 4.0.5的概率進(jìn)行圖片翻轉(zhuǎn),flip *注意:a. 第一種transform屬于像素級(jí)別的圖像增強(qiáng),目標(biāo)相對(duì)于圖片的位置沒有改變,因此bbox坐標(biāo)不需要變化。但是2,3,4,5都屬于圖片的幾何變化,目標(biāo)相對(duì)于圖片的位置被改變,因此bbox坐標(biāo)要進(jìn)行相應(yīng)變化。在TRAIN和TEST時(shí)都要進(jìn)行的transform有: 1.統(tǒng)一圖像大小到(224,224),resize 2.PIL to Tensor 3.歸一化,FT.normalize()注1: resize也是一種幾何變化,要知道應(yīng)用數(shù)據(jù)增強(qiáng)策略時(shí),哪些屬于幾何變化,哪些屬于像素變化 注2: PIL to Tensor操作,normalize操作必須執(zhí)行 """def transform(image, boxes, labels, difficulties, split):"""Apply the transformations above.:param image: image, a PIL Image:param boxes: bounding boxes in boundary coordinates, a tensor of dimensions (n_objects, 4):param labels: labels of objects, a tensor of dimensions (n_objects):param difficulties: difficulties of detection of these objects, a tensor of dimensions (n_objects):param split: one of 'TRAIN' or 'TEST', since different sets of transformations are applied:return: transformed image, transformed bounding box coordinates, transformed labels, transformed difficulties"""#在訓(xùn)練和測(cè)試時(shí)使用的transform策略往往不完全相同,所以需要split變量指明是TRAIN還是TEST時(shí)的transform方法assert split in {'TRAIN', 'TEST'}# Mean and standard deviation of ImageNet data that our base VGG from torchvision was trained on# see: https://pytorch.org/docs/stable/torchvision/models.html#為了防止由于圖片之間像素差異過大而導(dǎo)致的訓(xùn)練不穩(wěn)定問題,圖片在送入網(wǎng)絡(luò)訓(xùn)練之間需要進(jìn)行歸一化#對(duì)所有圖片各通道求mean和std來獲得mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]new_image = imagenew_boxes = boxesnew_labels = labelsnew_difficulties = difficulties# Skip the following operations for evaluation/testingif split == 'TRAIN':# A series of photometric distortions in random order, each with 50% chance of occurrence, as in Caffe reponew_image = photometric_distort(new_image)# Convert PIL image to Torch tensornew_image = FT.to_tensor(new_image)# Expand image (zoom out) with a 50% chance - helpful for training detection of small objects# Fill surrounding space with the mean of ImageNet data that our base VGG was trained onif random.random() < 0.5:new_image, new_boxes = expand(new_image, boxes, filler=mean)# Randomly crop image (zoom in)new_image, new_boxes, new_labels, new_difficulties = random_crop(new_image, new_boxes, new_labels,new_difficulties)# Convert Torch tensor to PIL imagenew_image = FT.to_pil_image(new_image)# Flip image with a 50% chanceif random.random() < 0.5:new_image, new_boxes = flip(new_image, new_boxes)# Resize image to (224, 224) - this also converts absolute boundary coordinates to their fractional formnew_image, new_boxes = resize(new_image, new_boxes, dims=(224, 224))# Convert PIL image to Torch tensornew_image = FT.to_tensor(new_image)# Normalize by mean and standard deviation of ImageNet data that our base VGG was trained onnew_image = FT.normalize(new_image, mean=mean, std=std)return new_image, new_boxes, new_labels, new_difficulties4.最后,構(gòu)建DataLoader
至此,我們已經(jīng)將VOC數(shù)據(jù)轉(zhuǎn)換成了dataset,接下來可以用來創(chuàng)建dataloader,這部分pytorch已經(jīng)幫我們實(shí)現(xiàn)好了,我們只需將創(chuàng)建好的dataset送入即可,注意理解相關(guān)參數(shù)。
"""pythonDataLoader """ #參數(shù)說明: #在train時(shí)一般設(shè)置shufle=True打亂數(shù)據(jù)順序,增強(qiáng)模型的魯棒性 #num_worker表示讀取數(shù)據(jù)時(shí)的線程數(shù),一般根據(jù)自己設(shè)備配置確定(如果是windows系統(tǒng),建議設(shè)默認(rèn)值0,防止出錯(cuò)) #pin_memory,在計(jì)算機(jī)內(nèi)存充足的時(shí)候設(shè)置為True可以加快內(nèi)存中的tensor轉(zhuǎn)換到GPU的速度,具體原因可以百度哈~ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=train_dataset.collate_fn, num_workers=workers,pin_memory=True) # note that we're passing the collate function here3.2.3 小結(jié)
到這里,這一小節(jié)的內(nèi)容就介紹完了。
回顧下,本節(jié)中,我們首先介紹了VOC數(shù)據(jù)集的基本信息以及如何下載,隨后我們介紹了和讀取VOC數(shù)據(jù)集的相關(guān)代碼。
萬事俱備,只欠模型~
總結(jié)
以上是生活随笔為你收集整理的动手学CV-目标检测入门教程2:VOC数据集的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 协议簇: Media Access Co
- 下一篇: sync.Map 源码学习