【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度
生活随笔
收集整理的這篇文章主要介紹了
【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度
小编觉得挺不错的,现在分享给大家,帮大家做个参考。
@Author:Runsen
上次基于CIFAR-10数据集,使用PyTorch构建的图像分类模型的精确度是60%。如何提升精确度?常见的方法就是使用transforms进行图像数据增强。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import os
import warnings
from matplotlib import pyplot as plt

# Silence all Python warnings to keep the tutorial output readable.
warnings.filterwarnings('ignore')

# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Load the dataset
# ---------------------------------------------------------------------------
# The previous tutorial's transform only rescaled and normalized the images;
# here RandomCrop and RandomHorizontalFlip are added for data augmentation.
# The transforms are defined first because the datasets below reference them.

# define a transform to normalize the data
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # pad each border by 4 px, then take a random 32x32 crop
    transforms.RandomHorizontalFlip(),     # randomly mirror the image left-right
    transforms.ToTensor(),                 # converting images to tensor
    # if the image dataset is black and white image, there can be just one number.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Evaluation images are not augmented: only tensor conversion and the same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# number of images in one forward and backward pass
batch_size = 128

# number of subprocesses used for data loading
# Normally do not use it if your os is windows
num_workers = 2

train_dataset = datasets.CIFAR10('./data/CIFAR10/', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# NOTE(review): this "validation" set is the *training* split (train=True),
# only without augmentation, so validation loss measures fit on images the
# model has already trained on rather than generalization. Consider holding
# out a disjoint subset of the training split instead.
val_dataset = datasets.CIFAR10('./data/CIFAR10', train=True, transform=transform_test)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

test_dataset = datasets.CIFAR10('./data/CIFAR10', train=False, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# declare classes in CIFAR10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# ---------------------------------------------------------------------------
# Visualize some images
# ---------------------------------------------------------------------------
# function that will be used for visualizing the data
def imshow(img):
    """Show one image array with matplotlib after undoing the normalization.

    Assumes ``img`` is a NumPy array laid out channels-first (C, H, W) and
    normalized with mean=0.5 / std=0.5 per channel, as the transforms in
    this tutorial do — TODO confirm for other inputs.
    """
    img = img / 2 + 0.5  # unnormalize: invert Normalize(mean=0.5, std=0.5)
    plt.imshow(np.transpose(img, (1, 2, 0)))  # channels-first -> channels-last for matplotlib


# obtain one batch of images from train dataset
# (these come from train_loader, so they already carry random crops/flips)
dataiter = iter(train_loader)
# Use the builtin next(): newer PyTorch DataLoader iterators no longer
# provide a .next() method.
images, labels = next(dataiter)
images = images.numpy()  # convert images to numpy for display

# plot the images in one batch with the corresponding labels
fig = plt.figure(figsize=(25, 4))
# display images
for idx in np.arange(10):
    ax = fig.add_subplot(1, 10, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx].item()])

# ---------------------------------------------------------------------------
# Build a common CNN model
# ---------------------------------------------------------------------------
# define the CNN architecture
class CNN(nn.Module):
    """Three conv blocks followed by a three-layer fully connected classifier.

    Expects a batch of 3x32x32 images and returns 10 unnormalized class
    scores (logits) per image; no softmax is applied, which is what
    ``nn.CrossEntropyLoss`` expects.

    Each conv block is Conv(3x3, padding=1) -> ReLU -> MaxPool(2) -> BatchNorm.
    The convolution keeps the spatial size and the pool halves it, so the
    feature maps go 32 -> 16 -> 8 -> 4. For a conv/pool layer the output size
    is O = (N + 2P - F) / S + 1.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input: 3x32x32
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),  # -> 32x32x32
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 32x16x16
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> 64x16x16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # -> 64x8x8
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 3, padding=1),  # -> 128x8x8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # -> 128x4x4
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 10),
        )

    def forward(self, x):
        # Conv and pooling layers: (B, 3, 32, 32) -> (B, 128, 4, 4)
        x = self.main(x)
        # Flatten everything except the batch dimension. Unlike
        # x.view(-1, 128*4*4), this can never fold samples into each other
        # for a wrongly sized input; such an input fails loudly in self.fc.
        x = torch.flatten(x, 1)
        # Fully connected layers -> (B, 10) logits
        x = self.fc(x)
        return x


cnn = CNN().to(device)
cnn  # bare expression: displays the model structure in a notebook
torch.nn.CrossEntropyLoss内部结合了LogSoftmax和NLLLoss,因此模型直接输出未经softmax的logits即可,无需先把输出转换为0到1之间的概率,它用于计算分类模型的损失。
训练模型
# Hyper Parameters
learning_rate = 0.001
train_losses = []  # every training batch loss, accumulated across all epochs
val_losses = []    # every validation batch loss, accumulated across all epochs

# Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)


# define train function that trains the model using a CIFAR10 dataset
def train(model, epoch, num_epochs):
    """Run one training epoch of ``model`` over ``train_loader``.

    Every batch loss is appended to the global ``train_losses``. Every 100
    batches the running mean loss of the *current* epoch is printed.
    """
    model.train()
    # len(loader) also counts the final partial batch, unlike len(dataset) // batch_size.
    total_batch = len(train_loader)
    epoch_losses = []
    for i, (images, labels) in enumerate(train_loader):
        X = images.to(device)
        Y = labels.to(device)

        # forward pass and loss calculation
        pred = model(X)
        cost = criterion(pred, Y)

        # backward pass and optimization
        optimizer.zero_grad()  # gradients accumulate in PyTorch, so clear them first
        cost.backward()
        optimizer.step()

        # training stats (record the current batch before reporting)
        epoch_losses.append(cost.item())
        train_losses.append(epoch_losses[-1])
        if (i + 1) % 100 == 0:
            print('Train, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(epoch_losses)))


# def the validation function that validates the model using CIFAR10 dataset
def validation(model, epoch, num_epochs):
    """Evaluate ``model`` on ``val_loader`` without updating its weights.

    Every batch loss is appended to the global ``val_losses``. Every 100
    batches the running mean loss of the current epoch is printed.
    """
    model.eval()  # disable Dropout and use BatchNorm running statistics
    total_batch = len(val_loader)
    epoch_losses = []
    with torch.no_grad():
        for i, (images, labels) in enumerate(val_loader):
            X = images.to(device)
            Y = labels.to(device)
            pred = model(X)
            cost = criterion(pred, Y)

            epoch_losses.append(cost.item())
            val_losses.append(epoch_losses[-1])
            if (i + 1) % 100 == 0:
                print('Validation, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1, total_batch, np.average(epoch_losses)))


def plot_losses(train_losses, val_losses):
    """Plot the per-batch training and validation loss curves."""
    plt.figure(figsize=(5, 5))
    plt.plot(train_losses, label='Train', alpha=0.5)
    plt.plot(val_losses, label='Validation', alpha=0.5)
    # One point per batch, not per epoch.
    plt.xlabel('Iterations')
    plt.ylabel('Losses')
    plt.legend()
    # Positional form works on all Matplotlib versions; the b= keyword was
    # deprecated in favor of visible= and no longer works on current releases.
    plt.grid(True)
    plt.title('CIFAR 10 Train/Val Losses Over Iterations')
    plt.show()


num_epochs = 20
for epoch in range(num_epochs):
    train(cnn, epoch, num_epochs)
    validation(cnn, epoch, num_epochs)
    # checkpoint the weights after every epoch
    torch.save(cnn.state_dict(), './data/Tutorial_3_CNN_Epoch_{}.pkl'.format(epoch + 1))

plot_losses(train_losses, val_losses)
测试模型
经过图像数据增强,模型的精确度从60%提升到了84%。
测试模型在哪些类别上表现良好:
# Per-class accuracy of the trained network on the test set.
class_correct = [0.0] * 10
class_total = [0.0] * 10

# Evaluation mode: Dropout off, BatchNorm uses its running statistics.
cnn.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = cnn(images)
        _, predicted = torch.max(outputs, 1)  # index of the highest logit = predicted class
        c = predicted == labels
        # Count every sample in the batch, however large the batch is.
        for label, correct in zip(labels.tolist(), c.tolist()):
            class_correct[label] += correct
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

# Overall accuracy across all test images.
print('Accuracy of the network on the test images: %d %%'
      % (100 * sum(class_correct) / sum(class_total)))

# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
以上是生活随笔为你收集整理的【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度的全部内容,希望文章能够帮你解决所遇到的问题。
- 上一篇: 开个便利店需要什么条件 最重要的其实还是
- 下一篇: b2c是什么意思