MXNet中x.grad源码追溯
生活随笔
收集整理的這篇文章主要介紹了
MXNet中x.grad源码追溯
小編覺得挺不錯的,現在分享給大家,幫大家做個參考.
Python測試代碼如https://zh.gluon.ai/chapter_prerequisite/autograd.html
本文追溯x.grad這一行代碼的調用
grad調用的是函數MXNDArrayGetGrad,/usr/local/lib/python3.7/dist-packages/mxnet-1.5.0-py3.7.egg/mxnet/ndarray/ndarray.py
MXNDArrayGetGrad的源碼依舊是在文件src/c_api/c_api.cc中,
NDArray ret = arr->grad();
ret就是獲取到的梯度
這里grad的源碼文件為src/ndarray/ndarray.cc,
Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);return info.out_grads[0];
這里Imperative::AGInfo::Get的源碼文件為?include/mxnet/imperative.h
return dmlc::get<AGInfo>(node->info);
這里get的源碼文件為3rdparty/dmlc-core/include/dmlc/any.h
return *any::TypeInfo<T>::get_ptr(&(src.data_));
這個get_ptr調用的是同文件中的如下代碼:
template<typename T>
class any::TypeOnHeap {public:inline static T* get_ptr(any::Data* data) {return static_cast<T*>(data->pheap);}
回到上面的代碼,那個entry_是NDArrary類的一個對象:
/*! \brief node entry for autograd */nnvm::NodeEntry entry_;
NodeEntry 源碼文件為include/nnvm/node.h,
#大體來講,梯度就是arr->entry_.node->info.data_.pheap;
總結
以上是生活随笔為你收集整理的MXNet中x.grad源码追溯的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: MXNET源码中TShape值的获取和打
- 下一篇: mxnet 中的 DepthwiseCo