pytorch分布式训练(一):torch.nn.DataParallel
??本文介紹最簡單的pytorch分布式訓(xùn)練方法:使用torch.nn.DataParallel這個API來實(shí)現(xiàn)分布式訓(xùn)練。環(huán)境為單機(jī)多gpu,不妨假設(shè)有4個可用的gpu。
一、構(gòu)建方法
使用這個API實(shí)現(xiàn)分布式訓(xùn)練的步驟非常簡單,總共分為3步驟:
1、創(chuàng)建一個model,并將該model推到某個gpu上(這個gpu也將作為output_device,后面具體解釋含義),不妨假設(shè)推到第0號gpu上,
2、將數(shù)據(jù)推到output_device對應(yīng)的gpu上,
data = data.to(device)3、使用torch.nn.DataParallel這個API來在0,1,2,3四個gpu上構(gòu)建分布式模型,
model = torch.nn.DataParallel(model, device_ids=[0,1,2,3], output_device=0)然后這個model就可以像普通的單gpu上的模型一樣開始訓(xùn)練了。
二、原理詳解
2.1 原理圖
??首先通過圖來看一下這個最簡單的分布式訓(xùn)練API的工作原理,然后結(jié)合代碼詳細(xì)闡述。
將模型和數(shù)據(jù)推入output_device(也就是0號)gpu上。
0號gpu將當(dāng)前模型在其他幾個gpu上進(jìn)行復(fù)制,同步模型的parameter、buffer和modules等;將當(dāng)前batch盡可能平均的分為len(device)=4份,分別推給每一個設(shè)備,并開啟多線程分別在每個設(shè)備上進(jìn)行前向傳播,得到各自的結(jié)果,最后將各自的結(jié)果全部匯總在一起,拷貝回0號gpu。
在0號gpu進(jìn)行反向傳播和模型的參數(shù)更新,并將結(jié)果同步給其他幾個gpu,即完成了一個batch的訓(xùn)練。
2.2 代碼原理
??通過分析torch.nn.DataParallel的代碼,可以看到具體的過程,這里重點(diǎn)看一下幾個關(guān)鍵的地方。
# 繼承自nn.Module,只要實(shí)現(xiàn)__init__和forward函數(shù)即可 class DataParallel(Module):# 構(gòu)造函數(shù)里沒有什么關(guān)鍵內(nèi)容,主要是根據(jù)傳進(jìn)來的model、device_ids和output_device進(jìn)行一些變量生成def __init__(self, module, device_ids=None, output_device=None, dim=0):super(DataParallel, self).__init__()device_type = _get_available_device_type()if device_type is None:self.module = moduleself.device_ids = []returnif device_ids is None:device_ids = _get_all_device_indices()if output_device is None:output_device = device_ids[0]self.dim = dimself.module = moduleself.device_ids = list(map(lambda x: _get_device_index(x, True), device_ids))self.output_device = _get_device_index(output_device, True)self.src_device_obj = torch.device(device_type, self.device_ids[0])_check_balance(self.device_ids)if len(self.device_ids) == 1:self.module.to(self.src_device_obj)def forward(self, *inputs, **kwargs):if not self.device_ids:return self.module(*inputs, **kwargs)for t in chain(self.module.parameters(), self.module.buffers()):if t.device != self.src_device_obj:raise RuntimeError("module must have its parameters and buffers ""on device {} (device_ids[0]) but found one of ""them on device: {}".format(self.src_device_obj, t.device))inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)if len(self.device_ids) == 1:return self.module(*inputs[0], **kwargs[0])# 在每個gpu上都復(fù)制一個modelreplicas = self.replicate(self.module, self.device_ids[:len(inputs)])# 開啟多線程進(jìn)行前向傳播,得到結(jié)果outputs = self.parallel_apply(replicas, inputs, kwargs)# 將每個gpu上得到的結(jié)果都gather到0號gpu上return self.gather(outputs, self.output_device)def replicate(self, module, device_ids):return replicate(module, device_ids, not torch.is_grad_enabled())def scatter(self, inputs, kwargs, device_ids):return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)def parallel_apply(self, replicas, inputs, kwargs):return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])def gather(self, outputs, output_device):return gather(outputs, output_device, dim=self.dim)再看一下parallel_apply這個關(guān)鍵的函數(shù),
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):assert len(modules) == len(inputs)if kwargs_tup is not None:assert len(modules) == len(kwargs_tup)else:kwargs_tup = ({},) * len(modules)if devices is not None:assert len(modules) == len(devices)else:devices = [None] * len(modules)devices = list(map(lambda x: _get_device_index(x, True), devices))# 創(chuàng)建一個互斥鎖,防止前后兩個batch的數(shù)據(jù)覆蓋lock = threading.Lock()results = {}grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()# 線程的target函數(shù),實(shí)現(xiàn)每個gpu上進(jìn)行推理,其中i為gpu編號def _worker(i, module, input, kwargs, device=None):torch.set_grad_enabled(grad_enabled)if device is None:device = get_a_var(input).get_device()try:# 根據(jù)當(dāng)前gpu編號確定推理硬件環(huán)境with torch.cuda.device(device), autocast(enabled=autocast_enabled):# this also avoids accidental slicing of `input` if it is a Tensorif not isinstance(input, (list, tuple)):input = (input,)output = module(*input, **kwargs)# 鎖住賦值,防止后一個batch的數(shù)據(jù)將前一個batch的結(jié)果覆蓋with lock:results[i] = outputexcept Exception:with lock:results[i] = ExceptionWrapper(where="in replica {} on device {}".format(i, device))if len(modules) > 1:# 創(chuàng)建多個線程,進(jìn)行不同gpu的前向推理threads = [threading.Thread(target=_worker,args=(i, module, input, kwargs, device))for i, (module, input, kwargs, device) inenumerate(zip(modules, inputs, kwargs_tup, devices))]for thread in threads:thread.start()for thread in threads:thread.join()else:_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])# 將不同gpu上推理的結(jié)果打包起來,后面會gather到output_device上outputs = []for i in range(len(inputs)):output = results[i]if isinstance(output, ExceptionWrapper):output.reraise()outputs.append(output)return outputs結(jié)論
??至此我們看到了torch.nn.DataParallel模塊進(jìn)行分布式訓(xùn)練的原理,數(shù)據(jù)和模型首先推入output_device對應(yīng)的gpu,然后將分成多個子batch的數(shù)據(jù)和模型分別推給其他gpu,每個gpu單獨(dú)處理各自的子batch,結(jié)果再打包回原output_device對應(yīng)的gpu算梯度和更新參數(shù),如此循環(huán)往復(fù),其本質(zhì)是一個單進(jìn)程多線程的并發(fā)程序。
??由此我們也很容易得到torch.nn.DataParallel模塊進(jìn)行分布式的缺點(diǎn),
1、每個batch的數(shù)據(jù)先分發(fā)到各gpu上,結(jié)果再打包回output_device上,在output_device一個gpu上進(jìn)行梯度計(jì)算和參數(shù)更新,再把更新同步給其他gpu上的model。其中涉及數(shù)據(jù)的來回拷貝,網(wǎng)絡(luò)通信耗時嚴(yán)重,GPU使用率低。
2、這種模式只支持單機(jī)多gpu的硬件拓?fù)浣Y(jié)構(gòu),不支持Apex的混合精度訓(xùn)練等。
3、torch.nn.DataParallel也沒有很完整的考慮到多個gpu做數(shù)據(jù)并行的一些問題,比如batchnorm,在訓(xùn)練時各個gpu上的batchnorm的mean和variance是子batch的計(jì)算結(jié)果,而不是原來整個batch的值,可能會導(dǎo)致訓(xùn)練不穩(wěn)定影響收斂等問題。
總結(jié)
以上是生活随笔為你收集整理的pytorch分布式训练(一):torch.nn.DataParallel的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 编译处理过程
- 下一篇: C语言编译过程总结详解