pytorch中深度拷贝_pytorch:对比clone、detach以及copy_等张量复制操作
pytorch提供了clone、detach、copy_和new_tensor等多種張量的復(fù)制操作,尤其前兩者在深度學(xué)習(xí)的網(wǎng)絡(luò)架構(gòu)中經(jīng)常被使用,本文旨在對比這些操作的差別。
1. clone
返回一個和源張量同shape、dtype和device的張量,與源張量不共享數(shù)據(jù)內(nèi)存,但提供梯度的回溯。
下面,通過例子來詳細(xì)說明:
示例:
(1)定義
import torch
a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)
a_ = a.clone()
print(a_) # tensor(1., device='cuda:0', dtype=torch.float64, grad_fn=)
注意:grad_fn=,說明clone后的返回值是個中間variable,因此支持梯度的回溯。因此,clone操作在一定程度上可以視為是一個identity-mapping函數(shù)。
(2)梯度的回溯
clone作為一個中間variable,會將梯度傳給源張量進行疊加。
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
a_ = a.clone()
z = a_ * 3
y.backward()
print(a.grad) # 2
z.backward()
print(a_.grad) # None. 中間variable,無grad
print(a.grad) # 5. a_的梯度會傳遞回給a,因此2+3=5
但若源張量的require_grad=False,而clone后的張量require_grad=True,顯然此時不存在張量回溯現(xiàn)象,clone后的張量可以求導(dǎo)。
import torch
a = torch.tensor(1.0)
a_ = a.clone()
a_.requires_grad_()
y = a_ ** 2
y.backward()
print(a.grad) # None
print(a_.grad) # 2. 可得到導(dǎo)數(shù)
(3)張量數(shù)據(jù)非共享
import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.clone()
a.data *= 3
a_ += 1
print(a) # tensor(3., requires_grad=True)
print(a_) # tensor(2., grad_fn=). 注意grad_fn的變化
綜上論述,clone操作在不共享數(shù)據(jù)內(nèi)存的同時支持梯度回溯,所以常用在神經(jīng)網(wǎng)絡(luò)中某個單元需要重復(fù)使用的場景下。
2. detach
detach的機制則與clone完全不同,即返回一個和源張量同shape、dtype和device的張量,與源張量共享數(shù)據(jù)內(nèi)存,但不提供梯度計算,即requires_grad=False,因此脫離計算圖。
同樣,通過例子來詳細(xì)說明:
(1)定義
import torch
a = torch.tensor(1.0, requires_grad=True, device="cuda", dtype=torch.float64)
a_ = a.detach()
print(a_) # tensor(1., device='cuda:0', dtype=torch.float64)
(2)脫離原計算圖
import torch
a = torch.tensor(1.0, requires_grad=True)
y = a ** 2
a_ = a.detach()
print(a_.grad) # None,requires_grad=False
a_.requires_grad_() # 強制其requires_grad=True,從而支持求導(dǎo)
z = a_ * 3
y.backward()
z.backward()
print(a.grad) # 2,與a_無關(guān)系
print(a_.grad) #
可見,detach后的張量,即使重新定義requires_grad=True,也與源張量的梯度沒有關(guān)系。
(3)共享張量數(shù)據(jù)內(nèi)存
import torch
a = torch.tensor(1.0, requires_grad=True)
a_ = a.detach()
print(a) # tensor(1., requires_grad=True)
print(a_) # tensor(1.)
a_ += 1
print(a) # tensor(2., requires_grad=True)
print(a_) # tensor(2.)
a.data *= 2
print(a) # tensor(4., requires_grad=True)
print(a_) # tensor(4.)
綜上論述,detach操作在共享數(shù)據(jù)內(nèi)存的脫離計算圖,所以常用在神經(jīng)網(wǎng)絡(luò)中僅要利用張量數(shù)值,而不需要追蹤導(dǎo)數(shù)的場景下。
3. clone和detach聯(lián)合使用
clone提供了非數(shù)據(jù)共享的梯度追溯功能,而detach又“舍棄”了梯度功能,因此clone和detach意味著著只做簡單的數(shù)據(jù)復(fù)制,既不數(shù)據(jù)共享,也不對梯度共享,從此兩個張量無關(guān)聯(lián)。
置于是先clone還是先detach,其返回值一樣,一般采用tensor.clone().detach()。
4. new_tensor
new_tensor可以將源張量中的數(shù)據(jù)復(fù)制到目標(biāo)張量(數(shù)據(jù)不共享),同時提供了更細(xì)致的device、dtype和requires_grad屬性控制:
new_tensor(data, dtype=None, device=None, requires_grad=False)
注意:其默認(rèn)參數(shù)下的操作等同于.clone().detach(),而requires_grad=True時的效果相當(dāng)于.clone().detach()requires_grad_(True)。上面兩種情況都推薦使用后者。
5. copy_
copy_同樣將源張量中的數(shù)據(jù)復(fù)制到目標(biāo)張量(數(shù)據(jù)不共享),其device、dtype和requires_grad一般都保留目標(biāo)張量的設(shè)定,僅僅進行數(shù)據(jù)復(fù)制,同時其支持broadcast操作。
a = torch.tensor([[1,2,3], [4,5,6]], device="cuda")
b = torch.tensor([7.0,8.0,9.0], requires_grad=True)
a.copy_(b)
print(a) # tensor([[7, 8, 9], [7, 8, 9]], device='cuda:0')
【Ref】:
總結(jié)
以上是生活随笔為你收集整理的pytorch中深度拷贝_pytorch:对比clone、detach以及copy_等张量复制操作的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 怎么设置php.ini允许sql语句插入
- 下一篇: win2008无法用计算机名共享,Win