pytorch 查看中间变量的梯度
生活随笔
收集整理的這篇文章主要介紹了
pytorch 查看中间变量的梯度
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
pytorch 為了節省顯存,在反向傳播的過程中只針對計算圖中的葉子結點(leaf variable)保留了梯度值(gradient)。但對于開發者來說,有時我們希望探測某些中間變量(intermediate variable) 的梯度來驗證我們的實現是否有誤,這個過程就需要用到 tensor的register_hook接口。一段簡單的示例代碼如下,代碼主要來自pytorch開發者的回答,筆者稍作修改使其更符合最新版的pytorch 語法(v1.2.0)。
grads = {}def save_grad(name):def hook(grad):grads[name] = gradreturn hookx = torch.randn(1, requires_grad=True) y = 3*x z = y * y# 為中間變量注冊梯度保存接口,存儲梯度時名字為 y。 y.register_hook(save_grad('y'))# 反向傳播 z.backward()# 查看 y 的梯度值 print(grads['y'])一個示例輸出是:
tensor([-1.5435])轉載于:https://www.cnblogs.com/SivilTaram/p/pytorch_intermediate_variable_gradient.html
總結
以上是生活随笔為你收集整理的pytorch 查看中间变量的梯度的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: [FY20 创新人才班 ASE] 第 1
- 下一篇: 花旗银行信用卡制卡要多久能到?卡片审核时