Mixed-Precision Training with Single and Half Precision
Overview
Mixed-precision training accelerates deep neural network training by mixing single-precision (FP32) and half-precision (FP16) data formats, while preserving the network accuracy achievable with pure single-precision training. It speeds up computation, reduces memory usage and memory traffic, and makes it possible to train larger models or batch sizes on a given piece of hardware.
For FP16 operators, if the supplied data type is FP32, the MindSpore backend reduces the precision. You can enable INFO-level logging and search for the keyword "Reduce precision" to see which operators were handled this way.
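As an illustration (assuming MindSpore's glog-based logging, where the GLOG_v environment variable selects the level and 1 means INFO), the log level can be raised before MindSpore is imported:

```python
# Sketch, assuming GLOG_v controls MindSpore's log level (1 = INFO).
# The variable must be set before MindSpore is imported; search the
# resulting log output for the keyword "Reduce precision".
import os
os.environ["GLOG_v"] = "1"

import mindspore
```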
Computation Flow
The typical mixed-precision computation flow in MindSpore is as follows (a numerical sketch follows the list):
- Parameters are stored in FP32.
- During the forward pass, when an FP16 operator is encountered, its inputs and parameters are cast from FP32 to FP16 for computation.
- The loss layer is set to compute in FP32.
- During the backward pass, the loss is first multiplied by the loss-scale value, preventing underflow from very small backward gradients.
- FP16 parameters participate in the gradient computation, and the results are cast back to FP32.
- The gradients are divided by the loss-scale value to undo the scaling.
- The gradients are checked for overflow; if overflow occurred, the update is skipped, otherwise the optimizer updates the original parameters in FP32.
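To make these steps concrete, here is a minimal, framework-agnostic sketch of a single training step with a fixed loss scale, written in plain NumPy; a single linear layer stands in for the network, and every name in it is illustrative rather than a MindSpore API.

```python
import numpy as np

# Framework-agnostic sketch of one mixed-precision training step with a
# fixed loss scale; a single linear layer stands in for the network.
rng = np.random.default_rng(0)
x = rng.standard_normal((64, 512)).astype(np.float32)   # input batch
y = rng.standard_normal((64, 128)).astype(np.float32)   # targets
w_fp32 = 0.01 * rng.standard_normal((512, 128)).astype(np.float32)  # FP32 master weight

loss_scale = 1024.0
lr = 0.1

# Forward pass: cast input and parameters from FP32 to FP16 for the matmul
w_fp16 = w_fp32.astype(np.float16)
pred = x.astype(np.float16) @ w_fp16

# Loss layer computed in FP32
diff = pred.astype(np.float32) - y
loss = np.mean(diff ** 2)

# Backward pass: multiply by the loss scale first, so small gradients
# survive the FP16 computation without underflowing
grad_pred = ((2.0 * loss_scale / diff.size) * diff).astype(np.float16)
grad_w_fp16 = x.astype(np.float16).T @ grad_pred

# Cast the gradient back to FP32 and divide by the loss scale to undo it
grad_w_fp32 = grad_w_fp16.astype(np.float32) / loss_scale

# Skip the update on overflow; otherwise update the FP32 master weight
if np.isfinite(grad_w_fp32).all():
    w_fp32 -= lr * grad_w_fp32
```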
The rest of this article walks through this flow with an automatic mixed-precision example and a manual mixed-precision example.
Automatic Mixed Precision
To use automatic mixed precision, call the corresponding interface with the network to be trained and the optimizer as inputs; the interface converts the operators of the entire network to FP16 operators (except for BatchNorm operators and the operators involved in the loss). Mixed precision can be implemented either through the amp interface or through the Model interface.
The steps for the amp interface are:
- Import MindSpore's mixed-precision interface amp.
- Define the network: this step is no different from an ordinary network definition (there is no need to manually configure the precision of any operator).
- Use the amp.build_train_network interface to wrap the network model, optimizer, and loss function, and set the level parameter; see https://www.mindspore.cn/doc/api_python/zh-CN/r1.1/mindspore/mindspore.html#mindspore.build_train_network. In this step, MindSpore automatically performs the required type conversions on the relevant operators.
A code example:
```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
import mindspore.ops as ops
from mindspore.nn import Momentum
# The interface of auto mixed precision
from mindspore import amp

context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")

# Define network
class Net(nn.Cell):
    def __init__(self, input_channel, out_channel):
        super(Net, self).__init__()
        self.dense = nn.Dense(input_channel, out_channel)
        self.relu = ops.ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.relu(x)
        return x

# Initialize network
net = Net(512, 128)

# Define training data, label
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([64, 128]).astype(np.float32))

# Define Loss and Optimizer
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_network = amp.build_train_network(net, optimizer, loss, level="O3", loss_scale_manager=None)

# Run training
output = train_network(predict, label)
```
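Here loss_scale_manager=None leaves loss scaling to the default behavior for the chosen level. As a variation (a sketch assuming the r1.1 FixedLossScaleManager API), an explicit fixed loss scale can be supplied instead:

```python
# Sketch: pass an explicit fixed loss scale rather than None. With
# drop_overflow_update=False, gradients are unscaled before the update
# and no update steps are skipped.
from mindspore.train.loss_scale_manager import FixedLossScaleManager

manager = FixedLossScaleManager(loss_scale=1024.0, drop_overflow_update=False)
train_network = amp.build_train_network(net, optimizer, loss, level="O3",
                                        loss_scale_manager=manager)
```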
The steps for the Model interface are:
- Import MindSpore's model training interface Model.
- Define the network: this step is no different from an ordinary network definition (there is no need to manually configure the precision of any operator).
- Create the dataset; see https://www.mindspore.cn/tutorial/training/zh-CN/r1.1/use/data_preparation.html.
- Use the Model interface to wrap the network model, optimizer, and loss function, and set the amp_level parameter; see https://www.mindspore.cn/doc/api_python/zh-CN/r1.1/mindspore/mindspore.html#mindspore.Model. In this step, MindSpore automatically performs the required type conversions on the relevant operators.
A code example:
```python
import numpy as np
import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy
from mindspore import context, Model
from mindspore.common.initializer import Normal
from src.dataset import create_dataset

context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")

# Define network
class LeNet5(nn.Cell):
    """
    LeNet network

    Args:
        num_class (int): Number of classes. Default: 10.
        num_channel (int): Number of channels. Default: 1.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create dataset
ds_train = create_dataset("/dataset/MNIST/train", 32)

# Initialize network
network = LeNet5(10)

# Define Loss and Optimizer
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O3")

# Run training
model.train(epoch=10, train_dataset=ds_train)
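```

As a usage note, Model.train also accepts callbacks; a common companion (shown here as a sketch) is the built-in LossMonitor, which prints the loss as training proceeds:

```python
# Sketch: attach the built-in LossMonitor callback to print the loss.
from mindspore.train.callback import LossMonitor

model.train(epoch=10, train_dataset=ds_train, callbacks=[LossMonitor()])
```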
Manual Mixed Precision
MindSpore also supports manual mixed precision. Suppose only one dense layer in the network should be computed in FP32 while all other layers are computed in FP16. Mixed precision is configured at Cell granularity, and a Cell defaults to FP32.
The steps for manual mixed precision are:
- Define the network: this step is similar to step 2 of automatic mixed precision.
- Configure mixed precision: use net.to_float(mstype.float16) to set all operators in the Cell and its child Cells to FP16, then manually configure the model's dense operator back to FP32.
- Use TrainOneStepCell to wrap the network model and the optimizer.
A code example:
```python
import numpy as np
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore import Tensor, context
import mindspore.ops as ops
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn import Momentum

context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")

# Define network
class Net(nn.Cell):
    def __init__(self, input_channel, out_channel):
        super(Net, self).__init__()
        self.dense = nn.Dense(input_channel, out_channel)
        self.relu = ops.ReLU()

    def construct(self, x):
        x = self.dense(x)
        x = self.relu(x)
        return x

# Initialize network
net = Net(512, 128)

# Set mixed precision
net.to_float(mstype.float16)
net.dense.to_float(mstype.float32)

# Define training data, label
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([64, 128]).astype(np.float32))

# Define Loss and Optimizer
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()

# Run training
output = train_network(predict, label)
```
Constraints
When using mixed precision, the backward network can only be generated by the automatic differentiation feature and cannot be user-defined; otherwise MindSpore may raise an exception about mismatched data formats.