[PyTorch Tutorial for Beginners] Part 8: Improving CIFAR-10 Accuracy with Image Data Augmentation
「@Author:Runsen」
In the previous tutorial, the image classification model built with PyTorch on the CIFAR-10 dataset reached about 60% accuracy. A common way to improve on that is image data augmentation with torchvision transforms.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

import numpy as np
import os
import warnings
from matplotlib import pyplot as plt

warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Load the dataset
# number of images in one forward and backward pass
batch_size = 128

# number of subprocesses used for data loading
# normally set this to 0 if your OS is Windows
num_workers = 2

# transform_train and transform_test are defined in the next block;
# they must exist before these datasets are created
train_dataset = datasets.CIFAR10('./data/CIFAR10/', train=True, download=True, transform=transform_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# the validation set reuses the training split, but without augmentation
val_dataset = datasets.CIFAR10('./data/CIFAR10', train=True, transform=transform_test)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

test_dataset = datasets.CIFAR10('./data/CIFAR10', train=False, transform=transform_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# declare the classes in CIFAR10
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

The transform in the previous tutorial only rescaled and normalized the images; here we add RandomCrop and RandomHorizontalFlip.
# define the transforms used to augment and normalize the data
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # pad by 4 pixels, then take a random 32x32 crop
    transforms.RandomHorizontalFlip(),      # flip left-right with probability 0.5
    transforms.ToTensor(),                  # convert images to tensors
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # a black-and-white dataset would need just one number per tuple
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
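The (0.5, 0.5, 0.5) values passed to Normalize are a convenient round number rather than the true channel statistics of CIFAR-10. If you prefer to normalize with the real statistics, they can be estimated once from the raw training images. Below is a minimal sketch; raw_train and raw_loader are illustrative names not used elsewhere in this tutorial, and the result comes out near (0.49, 0.48, 0.45) for the mean and (0.25, 0.24, 0.26) for the std.

# estimate the per-channel mean/std of CIFAR-10 from the raw training images
# (optional refinement; this tutorial keeps the simpler 0.5/0.5 normalization)
raw_train = datasets.CIFAR10('./data/CIFAR10/', train=True, download=True,
                             transform=transforms.ToTensor())
raw_loader = DataLoader(raw_train, batch_size=1000, shuffle=False)

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
n_pixels = 0
for images, _ in raw_loader:
    # images has shape (B, 3, 32, 32); sum over batch, height and width
    channel_sum += images.sum(dim=[0, 2, 3])
    channel_sq_sum += (images ** 2).sum(dim=[0, 2, 3])
    n_pixels += images.size(0) * images.size(2) * images.size(3)

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()
print(mean, std)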
Visualize a few of the images

# function that will be used for visualizing the data
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(np.transpose(img, (1, 2, 0)))  # convert from Tensor layout (C, H, W) to (H, W, C)

# obtain one batch of images from the train loader
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()  # convert images to numpy for display

# plot the images in one batch with the corresponding labels
fig = plt.figure(figsize=(25, 4))

# display the first 10 images
for idx in np.arange(10):
    ax = fig.add_subplot(1, 10, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])

Build a common CNN model
# define the CNN architecture
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.main = nn.Sequential(
            # input: 3x32x32
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),  # 32x32x32, since O = (N + 2P - F)/S + 1
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x16x16
            nn.BatchNorm2d(32),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 64x16x16
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 64x8x8
            nn.BatchNorm2d(64),

            nn.Conv2d(64, 128, 3, padding=1),  # 128x8x8
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # 128x4x4
            nn.BatchNorm2d(128),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        # conv and pooling layers
        x = self.main(x)
        # flatten before the fully connected layers
        x = x.view(-1, 128 * 4 * 4)
        # fully connected layers
        x = self.fc(x)
        return x

cnn = CNN().to(device)
cnn

For the loss we use torch.nn.CrossEntropyLoss, which takes the raw class scores (logits) produced by the last linear layer and applies LogSoftmax plus negative log-likelihood internally, so the network does not need a softmax layer at the end.
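This behaviour is easy to verify in isolation. Here is a tiny self-contained check, with logits and targets made up purely for illustration:

# minimal illustration of nn.CrossEntropyLoss on made-up logits
criterion_demo = nn.CrossEntropyLoss()

logits = torch.tensor([[2.0, 0.5, -1.0],   # sample 0: class 0 has the largest score
                       [0.1, 0.2,  3.0]])  # sample 1: class 2 has the largest score
targets = torch.tensor([0, 2])             # both targets match the largest score -> small loss

loss = criterion_demo(logits, targets)

# equivalent manual computation: log_softmax followed by negative log-likelihood
manual = nn.functional.nll_loss(nn.functional.log_softmax(logits, dim=1), targets)
print(loss.item(), manual.item())  # the two values are identical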
Train the model
# hyperparameters
learning_rate = 0.001
train_losses = []
val_losses = []

# loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

# train function that trains the model on the CIFAR10 training set
def train(model, epoch, num_epochs):
    model.train()
    total_batch = len(train_dataset) // batch_size
    for i, (images, labels) in enumerate(train_loader):
        X = images.to(device)
        Y = labels.to(device)

        # forward pass and loss calculation
        pred = model(X)
        cost = criterion(pred, Y)

        # gradient initialization, backward pass and parameter update
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        train_losses.append(cost.item())

        # training stats
        if (i + 1) % 100 == 0:
            print('Train, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(train_losses)))

# validation function that evaluates the model on the validation loader
def validation(model, epoch, num_epochs):
    model.eval()
    total_batch = len(val_dataset) // batch_size
    for i, (images, labels) in enumerate(val_loader):
        X = images.to(device)
        Y = labels.to(device)
        with torch.no_grad():
            pred = model(X)
            cost = criterion(pred, Y)

        val_losses.append(cost.item())

        if (i + 1) % 100 == 0:
            print("Validation, Epoch [%d/%d], Iter [%d/%d], Loss: %.4f"
                  % (epoch + 1, num_epochs, i + 1, total_batch, np.average(val_losses)))

def plot_losses(train_losses, val_losses):
    plt.figure(figsize=(5, 5))
    plt.plot(train_losses, label='Train', alpha=0.5)
    plt.plot(val_losses, label='Validation', alpha=0.5)
    plt.xlabel('Iterations')
    plt.ylabel('Losses')
    plt.legend()
    plt.grid(True)
    plt.title('CIFAR 10 Train/Val Losses')
    plt.show()

num_epochs = 20
for epoch in range(num_epochs):
    train(cnn, epoch, num_epochs)
    validation(cnn, epoch, num_epochs)
    torch.save(cnn.state_dict(), './data/Tutorial_3_CNN_Epoch_{}.pkl'.format(epoch + 1))

plot_losses(train_losses, val_losses)
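Because a checkpoint is written after every epoch, any of the saved models can be restored later without retraining. A minimal sketch, assuming the file from the final epoch exists at the path used in the loop above:

# load a saved checkpoint into a fresh model instance before evaluation
# (assumes the training loop above produced ./data/Tutorial_3_CNN_Epoch_20.pkl)
cnn = CNN().to(device)
state_dict = torch.load('./data/Tutorial_3_CNN_Epoch_20.pkl', map_location=device)
cnn.load_state_dict(state_dict)
cnn.eval()  # switch to evaluation mode before testing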
Test the model

def test(model):
    # switch the model to evaluation mode
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataset:
            images = images.unsqueeze(0).to(device)  # add a batch dimension for the single image
            # forward pass
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += 1
            correct += (predicted == labels).sum().item()
    print("Accuracy of Test Images: %f %%" % (100 * float(correct) / total))

# run the evaluation
test(cnn)

With image data augmentation, the model's test accuracy improved from 60% to 84%.
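As a side note, looping over test_dataset one image at a time is slow. The same accuracy can be computed in batches through test_loader; here is a minimal alternative sketch, where test_batched is a new helper name rather than part of the original post:

# batched variant of the accuracy computation, using test_loader
def test_batched(model):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Accuracy of Test Images: %f %%" % (100 * correct / total))

Both versions report the same number; the batched one simply amortizes the forward passes over 128 images at a time.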
Finally, let's check which classes the model performs well on:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = cnn(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        # tally correct predictions for every image in the batch
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
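To see not only which classes are hard but also what they are mistaken for, the same predictions can be accumulated into a 10x10 confusion matrix. A minimal sketch along the lines of the loop above; the confusion tensor and the reporting format are illustrative additions:

# rows are true classes, columns are predicted classes
confusion = torch.zeros(10, 10, dtype=torch.long)
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = cnn(images)
        _, predicted = torch.max(outputs, 1)
        for t, p in zip(labels, predicted):
            confusion[t.item(), p.item()] += 1

# report the most common confusion for each class
for i in range(10):
    row = confusion[i].clone()
    row[i] = 0  # ignore correct predictions
    j = row.argmax().item()
    print('%5s is most often confused with %5s (%d times)' % (classes[i], classes[j], row[j].item()))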