Increasing the batch size when GPU memory is insufficient
Problem:
How can we increase the batch size when GPU memory is insufficient?
In other words, how can we increase the batch size without adding more GPU memory?
Idea:
Split each batch into several mini-batches, compute the loss on each mini-batch, sum the losses, and back-propagate on the sum.
This way, memory only has to hold a mini-batch's worth of data at a time, trading time for space.
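To make the idea concrete before the full example, here is a minimal sketch using a toy linear model and random data; every name in it (the stand-in model, the 1024-sample "large" batch, the mini-batch size of 64) is an illustrative assumption rather than part of the original article:

import torch
from torch import nn
from torch.utils.data import DataLoader

# stand-in model, loss and optimizer (assumptions for illustration only)
model = nn.Linear(10, 2)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# one "large" logical batch that we pretend cannot be processed in a single forward pass
big_inputs = torch.randn(1024, 10)
big_targets = torch.randint(0, 2, (1024,))

# re-split the large batch into mini-batches of 64
mini_loader = DataLoader(list(zip(big_inputs, big_targets)), batch_size=64)

loss = 0.
for inputs, targets in mini_loader:
    loss = loss + loss_func(model(inputs), targets)  # sum the mini-batch losses

optimizer.zero_grad()
loss.backward()   # back-propagate once on the summed loss
optimizer.step()  # one parameter update for the whole large batch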
PyTorch implementation:
import torch
from sklearn import metrics
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


# A simple TextRNN model
class TextRNN(nn.Module):
    def __init__(self, num_words, num_classes, embedding_dim, hidden_dim, dropout):
        super(TextRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=num_words + 1,
                                  embedding_dim=embedding_dim,
                                  padding_idx=num_words)
        self.encode = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.mlp = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, num_classes)
        )

    def forward(self, x, masks):
        # masks is accepted for interface consistency but not used here
        x = self.embed(x)
        x, _ = self.encode(x)
        x = x.max(1)[0]  # max-pool over the time dimension
        x = self.mlp(x)
        return x


# One epoch of training:
# split each batch into several small mini-batches,
# compute the loss of each mini-batch and sum them,
# then back-propagate once on the whole batch.
def train_eval(cate, loader, mini_batch_size, model, optimizer, loss_func):
    model.train() if cate == "train" else model.eval()  # set train/eval mode
    preds, labels, loss_sum = [], [], 0.  # loss_sum is only for reporting, never back-propagated
    for i, data in enumerate(loader):
        # re-split one large batch into mini-batches
        mini_loader = DataLoader(list(zip(*data)), batch_size=mini_batch_size)
        loss = 0.  # sum of the mini-batch losses, used for back-propagation
        for j, (inputs, masks, targets) in enumerate(mini_loader):
            y = model(inputs, masks)       # forward pass
            loss += loss_func(y, targets)  # accumulate the mini-batch losses
            # bookkeeping only, no back-propagation here
            preds.append(y.max(dim=1)[1].data)  # collect predictions
            labels.append(targets.data)         # collect labels
        # back-propagate the accumulated loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # track the loss for reporting
        loss_sum += loss.data
    preds = torch.cat(preds).tolist()
    labels = torch.cat(labels).tolist()
    loss = loss_sum / len(loader)
    acc = metrics.accuracy_score(labels, preds) * 100
    return loss, acc, preds, labels


if __name__ == '__main__':
    # model hyper-parameters
    num_words = 5000
    num_classes = 20
    embedding_dim = 300
    hidden_dim = 200
    dropout = 0.5
    # dataset parameters
    num_samples = 10000
    pad_len = 1000
    # training parameters
    batch_size = 4096
    mini_batch_size = 64
    lr = 1e-3
    weight_decay = 1e-6
    # build random toy data (the upper bound of torch.randint is exclusive)
    inputs = torch.randint(0, num_words + 1, (num_samples, pad_len))
    masks = torch.randint(0, 2, (num_samples, pad_len, 1)).float()
    targets = torch.randint(0, num_classes, (num_samples,))
    word2vec = torch.rand((num_words + 1, embedding_dim)).numpy()  # placeholder for pretrained embeddings (unused)
    dataset = list(zip(inputs, masks, targets))
    loader = DataLoader(dataset,
                        batch_size=batch_size,  # the batch on which loss is back-propagated
                        shuffle=True)
    # model, loss function, optimizer
    model = TextRNN(num_words=num_words, num_classes=num_classes,
                    embedding_dim=embedding_dim, hidden_dim=hidden_dim,
                    dropout=dropout)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # start training
    for epoch in range(1, 100):
        loss, acc, preds, labels = train_eval("train", loader, mini_batch_size,
                                              model, optimizer, loss_func)
        print("-" * 50)
        print(epoch, loss)
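One caveat: because the code above sums the mini-batch losses and calls backward() only once per large batch, PyTorch keeps every mini-batch's computation graph alive until that single backward pass, so activation memory is not fully reduced to one mini-batch. A widely used variant, gradient accumulation with a backward() call per mini-batch, frees each graph immediately after it is used. The sketch below is not from the original article; it reuses the same toy stand-ins as the earlier sketch and divides each loss by the number of mini-batches so the accumulated gradient matches the mean over the large batch (drop the division to match the plain sum used above):

import torch
from torch import nn
from torch.utils.data import DataLoader

# stand-in model, loss and optimizer (assumptions for illustration only)
model = nn.Linear(10, 2)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

big_inputs = torch.randn(1024, 10)           # one large logical batch
big_targets = torch.randint(0, 2, (1024,))
mini_loader = DataLoader(list(zip(big_inputs, big_targets)), batch_size=64)
num_mini_batches = len(mini_loader)

optimizer.zero_grad()
loss_sum = 0.
for inputs, targets in mini_loader:
    loss = loss_func(model(inputs), targets) / num_mini_batches  # scale toward the large-batch mean
    loss.backward()          # gradients accumulate in .grad; this mini-batch's graph is freed here
    loss_sum += loss.item()  # bookkeeping only
optimizer.step()             # one parameter update for the whole large batch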