PyTorch基础-使用LSTM神经网络实现手写数据集识别-08
生活随笔
收集整理的這篇文章主要介紹了
PyTorch基础-使用LSTM神经网络实现手写数据集识别-08
小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,幫大家做個(gè)參考.
import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 訓(xùn)練集
train_data = datasets.MNIST(root="./", # 存放位置train = True, # 載入訓(xùn)練集transform=transforms.ToTensor(), # 把數(shù)據(jù)變成tensor類型download = True # 下載)
# 測(cè)試集
test_data = datasets.MNIST(root="./",train = False,transform=transforms.ToTensor(),download = True)
# 批次大小
batch_size = 64
# 裝載訓(xùn)練集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 裝載測(cè)試集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):inputs,labels = dataprint(inputs.shape)print(labels.shape)break
# 定義網(wǎng)絡(luò)結(jié)構(gòu)
class LSTM(nn.Module):def __init__(self):super(LSTM,self).__init__()# 初始化self.lstm = torch.nn.LSTM(input_size = 28, # 表示輸入特征的大小hidden_size = 64, # 表示lstm模塊的數(shù)量num_layers = 1, # 表示lstm隱藏層的層數(shù)batch_first = True # lstm默認(rèn)格式input(seq_len,batch,feature)等于True表示input和output變成(batch,seq_len,feature))self.out = torch.nn.Linear(in_features=64,out_features=10)self.softmax = torch.nn.Softmax(dim=1)def forward(self,x):# (batch,seq_len,feature)x = x.view(-1,28,28)# output:(batch,seq_len,hidden_size)包含每個(gè)序列的輸出結(jié)果# 雖然lstm的batch_first為True,但是h_n,c_n的第0個(gè)維度還是num_layers# h_n :[num_layers,batch,hidden_size]只包含最后一個(gè)序列的輸出結(jié)果# c_n:[num_layers,batch,hidden_size]只包含最后一個(gè)序列的輸出結(jié)果output,(h_n,c_n) = self.lstm(x)output_in_last_timestep = h_n[-1,:,:]x = self.out(output_in_last_timestep)x = self.softmax(x)return x
# 定義模型
model = LSTM()
# 定義代價(jià)函數(shù)
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定義優(yōu)化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 隨機(jī)梯度下降
# 定義模型訓(xùn)練和測(cè)試的方法
def train():# 模型的訓(xùn)練狀態(tài)model.train()for i,data in enumerate(train_loader):# 獲得一個(gè)批次的數(shù)據(jù)和標(biāo)簽inputs,labels = data# 獲得模型預(yù)測(cè)結(jié)果(64,10)out = model(inputs)# 交叉熵代價(jià)函數(shù)out(batch,C:類別的數(shù)量),labels(batch)loss = mse_loss(out,labels)# 梯度清零optimizer.zero_grad()# 計(jì)算梯度loss.backward()# 修改權(quán)值optimizer.step()def test():# 模型的測(cè)試狀態(tài)model.eval()correct = 0 # 測(cè)試集準(zhǔn)確率for i,data in enumerate(test_loader):# 獲得一個(gè)批次的數(shù)據(jù)和標(biāo)簽inputs,labels = data# 獲得模型預(yù)測(cè)結(jié)果(64,10)out = model(inputs)# 獲得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 預(yù)測(cè)正確的數(shù)量correct += (predicted==labels).sum()print("Test acc:{0}".format(correct.item()/len(test_data)))correct = 0for i,data in enumerate(train_loader): # 訓(xùn)練集準(zhǔn)確率# 獲得一個(gè)批次的數(shù)據(jù)和標(biāo)簽inputs,labels = data# 獲得模型預(yù)測(cè)結(jié)果(64,10)out = model(inputs)# 獲得最大值,以及最大值所在的位置_,predicted = torch.max(out,1)# 預(yù)測(cè)正確的數(shù)量correct += (predicted==labels).sum()print("Train acc:{0}".format(correct.item()/len(train_data)))
# 訓(xùn)練
for epoch in range(10):print("epoch:",epoch)train()test()
總結(jié)
以上是生活随笔為你收集整理的PyTorch基础-使用LSTM神经网络实现手写数据集识别-08的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: PyTorch基础-使用卷积神经网络CN
- 下一篇: PyTorch基础-模型的保存和加载-0