PyTorch 笔记(20)— torchvision 的 datasets、transforms 数据预览和加载、模型搭建(torch.nn.Conv2d/MaxPool2d/Dropout)
計算機視覺是深度學習中最重要的一類應用,為了方便研究者使用,PyTorch 團隊專門開發了一個視覺工具包torchvision,這個包獨立于 PyTorch,需通過 pip instal torchvision 安裝。
torchvision 主要包含三部分:
models:提供深度學習中各種經典網絡的網絡結構以及預訓練好的模型,包括AlexNet、VGG系列、ResNet系列、Inception系列等;datasets: 提供常用的數據集加載,設計上都是繼承torch.utils.data.Dataset,主要包括MNIST、CIFAR10/100、ImageNet、COCO等;transforms:提供常用的數據預處理操作,主要包括對Tensor以及PIL Image對象的操作;
from torchvision import models
from torch import nn
from torchvision import datasets'''加載預訓練好的模型,如果不存在會進行下載
預訓練好的模型保存在 ~/.torch/models/下面'''
resnet34 = models.squeezenet1_1(pretrained=True, num_classes=1000)'''修改最后的全連接層為10分類問題(默認是ImageNet上的1000分類)'''
resnet34.fc=nn.Linear(512, 10)'''加上transform'''
transform = T.Compose([T.ToTensor(),T.Normalize(mean=[0.4,], std=[0.2,]),
])
'''
# 指定數據集路徑為data,如果數據集不存在則進行下載
# 通過train=False獲取測試集
'''
dataset = datasets.MNIST('data/', download=True, train=False, transform=transform)
Transforms 中涵蓋了大部分對 Tensor 和 PIL Image 的常用處理,這些已在上文提到,這里就不再詳細介紹。需要注意的是轉換分為兩步,
- 第一步:構建轉換操作,例如
transf = transforms.Normalize(mean=x, std=y), - 第二步:執行轉換操作,例如
output = transf(input)。另外還可將多個處理操作用Compose拼接起來,形成一個處理轉換流程。
from torchvision import transforms
to_pil = transforms.ToPILImage()
to_pil(t.randn(3, 64, 64))
輸出隨機噪聲,待補充:
torchvision 還提供了兩個常用的函數。
- 一個是
make_grid,它能將多張圖片拼接成一個網格中; - 另一個是
save_img,它能將Tensor保存成圖片。
len(dataset) # 10000
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)
from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
img = make_grid(next(dataiter)[0], 4) # 拼成4*4網格圖片,且會轉成3通道
to_img(img)
輸出:(待補充)
save_image(img, 'a.png')
Image.open('a.png')
輸出:(待補充)
1. datasets
使用 torchvision.datasets 可以輕易實現對這些數據集的訓練集和測試集的下載,只需要使用 torchvision.datasets 再加上需要下載的數據集的名稱就可以了。
比如在這個問題中我們要用到手寫數字數據集,它的名稱是 MNIST,那么實現下載的代碼就是
torchvision.datasets.MNIST。其他常用的數據集如 COCO、ImageNet、CIFCAR 等都可以通過這個方法快速下載和載入。實現數據集下載的代碼如下:
import torch as t
from torchvision import datasets, transformsdata_train = datasets.MNIST(root="./data", transform=transform, train=True, download=True)
data_test = datasets.MNIST(root="./data", transform=transform, train=False)
其中,
root用于指定數據集在下載之后的存放路徑,這里存放在根目錄下的data文件夾中;transform用于指定導入數據集時需要對數據進行哪種變換操作;
注意,要提前定義這些變換操作;train 用于指定在數據集下載完成后需要載入哪部分數據,
- 如果設置為
True,則說明載入的是該數據集的訓練集部分; - 如果設置為
False,則說明載入的是該數據集的測試集部分;
2. transforms
在計算機視覺中處理的數據集有很大一部分是圖片類型的,而在 PyTorch 中實際進行計算的是 Tensor 數據類型的變量,所以我們首先需要解決的是數據類型轉換的問題,如果獲取的數據是格式或者大小不一的圖片,則還需要進行歸一化和大小縮放等操作,慶幸的是,這些方法在 torch.transforms 中都能找到。
在 torch.transforms 中有大量的數據變換類,其中有很大一部分可以用于實現數據增強(DataArgumentation)。若在我們需要解決的問題上能夠參與到模型訓練中的圖片數據非常有限,則這時就要通過對有限的圖片數據進行各種變換,來生成新的訓練集了,這些變換可以是縮小或者放大圖片的大小、對圖片進行水平或者垂直翻轉等,都是數據增強的方法。
不過在手寫數字識別的問題上可以不使用數據增強的方法,因為可用于模型訓練的數據已經足夠了。對數據進行載入及有相應變化的代碼如下:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
我們可以將以上代碼中的 torchvision.transforms.Compose 類看作一種容器,它能夠同時對多種數據變換進行組合。傳入的參數是一個列表,列表中的元素就是對載入的數據進行的各種變換操作。
在以上代碼中,在 torchvision.transforms.Compose 中只使用了一個類型的轉換變換 transforms.ToTensor 和一個數據標準化變換transforms.Normalize。
這里使用的標準化變換也叫作標準差變換法,這種方法需要使用原始數據的均值(Mean)和標準差(StandardDeviation)來進行數據的標準化,在經過標準化變換之后,數據全部符合均值為0、標準差為1的標準正態分布。
下面看看在 torchvision.transforms 中常用的數據變換操作。
torchvision.transforms.Resize:用于對載入的圖片數據按我們需求的大小進行縮放。傳遞給這個類的參數可以是一個整型數據,也可以是一個類似于(h,w)的序列,其中,h代表高度,w代表寬度,但是如果使用的是一個整型數據,那么表示縮放的寬度和高度都是這個整型數據的值。torchvision.transforms.Scale:用于對載入的圖片數據按我們需求的大小進行縮放,用法和
torchvision.transforms.Resize類似。torchvision.transforms.CenterCrop:用于對載入的圖片以圖片中心為參考點,按我們需要的大小進行裁剪。傳遞給這個類的參數可以是一個整型數據,也可以是一個類似于(h,w)的序列。*torchvision.transforms.RandomCrop:用于對載入的圖片按我們需要的大小進行隨機裁剪。傳遞給這個類的參數可以是一個整型數據,也可以是一個類似于(h,w)的序列。torchvision.transforms.RandomHorizontalFlip:用于對載入的圖片按隨機概率進行水平翻轉。我們可以通過傳遞給這個類的參數自定義隨機概率,如果沒有定義,則使用默認的概率值 0.5。torchvision.transforms.RandomVerticalFlip:用于對載入的圖片按隨機概率進行垂直翻轉。我們可以通過傳遞給這個類的參數自定義隨機概率,如果沒有定義,則使用默認的概率值 0.5。torchvision.transforms.ToTensor:用于對載入的圖片數據進行類型轉換,將之前構成PIL圖片的數據轉換成Tensor數據類型的變量,讓PyTorch能夠對其進行計算和處理。torchvision.transforms.ToPILImage:用于將Tensor變量的數據轉換成PIL圖片數據,主要是為了方便圖片內容的顯示。
3. 數據預覽和加載
在數據下載完成并且載入后,我們還需要對數據進行裝載。我們可以將數據的載入理解為對圖片的處理,在處理完成后,我們就需要將這些圖片打包好送給我們的模型進行訓練了,而裝載就是這個打包
的過程。
在裝載時通過 batch_size 的值來確認每個包的大小,通過 shuffle 的值來確認是否在裝載的過程中打亂圖片的順序。裝載圖片的代碼如下:
data_loader_train = torch.utils.data.DataLoader(dataset=data_train, batch_size = 64,shuffle = True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test, batch_size=64,shuffle = True)
對數據的裝載使用的是 torch.utils.data.DataLoader 類,類中的
dataset參數用于指定我們載入的數據集名稱;batch_size參數設置了每個包中的圖片數據個數,代碼中的值是 64,所以在每個包中會包含64張圖片;shuffle參數設置為True,在裝載的過程會將數據隨機打亂順序并進行打包;
在裝載完成后,我們可以選取其中一個批次的數據進行預覽。進行數據預覽的代碼如下:
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)img = img.numpy().transpose(1,2,0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print([labels[i] for i in range(64))
在以上代碼中使用了 iter 和 next 來獲取一個批次的圖片數據和其對應的圖片標簽,然后使用torchvision.utils 中的 make_grid 類方法將一個批次的圖片構造成網格模式。
需要傳遞給 torchvision.utils.make_grid 的參數就是一個批次的裝載數據,每個批次的裝載數據都是 4 維的,維度的構成從前往后分別為 batch_size 、channel 、height 和 weight ,分別對應一個批次中的數據個數、每張圖片的色彩通道數、每張圖片的高度和寬度。
在通過 torchvision.utils.make_grid 之后,圖片的維度變成了( channel , height , weight ),這個批次的圖片全部被整合到了一起,所以在這個維度中對應的值也和之前不一樣了,但是色彩通道數保持不變。
若我們想使用Matplotlib將數據顯示成正常的圖片形式,則使用的數據首先必須是數組,其次這個數組的維度必須是(height,weight,channel),即色彩通道數在最后面。所以我們要通過 numpy 和 transpose 完成原始數據類型的轉換和數據維度的交換,這樣才能夠使用Matplotlib繪制出正確的圖像。
4. 模型搭建和參數優化
(1)torch.nn.Conv2d:用于搭建卷積神經網絡的卷積層,主要的輸入參數有輸入通道數、輸出通道數、卷積核大小、卷積核移動步長和Paddingde值。其中,輸入通道數的數據類型是整型,用于確定輸入數據的層數;輸出通道數的數據類型也是整型,用于確定輸出數據的層數;卷積核大小的數據類型是整型,用于確定卷積核的大小;卷積核移動步長的數據類型是整型,用于確定卷積核每次滑動的步長;Paddingde 的數據類型是整型,值為0時表示不進行邊界像素
的填充,如果值大于0,那么增加數字所對應的邊界像素層數。
(2)torch.nn.MaxPool2d:用于實現卷積神經網絡中的最大池化層,主要的輸入參數是池化窗口大小、池化窗口移動步長和Paddingde值。同樣,池化窗口大小的數據類型是整型,用于確定池化窗口的大小。池化窗口步長的數據類型也是整型,用于確定池化窗口每次移動的步長。Paddingde值和在torch.nn.Conv2d中定義的Paddingde值的用法和意義是一樣的。
(3)torch.nn.Dropout:torch.nn.Dropout類用于防止卷積神經網絡在訓練的過程中發生過擬合,其工作原理簡單來說就是在模型訓練的過程中,以一定的隨機概率將卷積神經網絡模型的部分參數歸零,以達到減少相鄰兩層神經連接的目的。圖 6-3顯示了 Dropout方法的效果。
總結
以上是生活随笔為你收集整理的PyTorch 笔记(20)— torchvision 的 datasets、transforms 数据预览和加载、模型搭建(torch.nn.Conv2d/MaxPool2d/Dropout)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: PyTorch 笔记(19)— Tens
- 下一篇: 真得有紫薇圣人吗!