MXNet autograd.record() source code analysis
Introduction
I am working on a lesion-detection project with YOLOv3 on the DeepLesion dataset under the MXNet framework. During training an "access violation reading 0xFFFFF..." error occurred, and searching on Google suggested that the cause was calling backward() inside the autograd.record() scope, as shown below. That prompted me to study MXNet's automatic differentiation process and how autograd.record() controls gradient computation at the lower level.
with autograd.record(mode == "train"):
    losses.backward()
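Following that suggestion, the fix is to keep only the forward computation inside the record scope and call backward() after leaving it. A minimal sketch of the adjusted pattern, where net, loss_fn, data and label are hypothetical placeholders standing in for the actual training code:

with autograd.record(mode == "train"):
    output = net(data)                  # forward pass, recorded for autograd
    losses = loss_fn(output, label)     # loss is also computed inside the recorded scope
losses.backward()                       # backward is called outside the record scope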
1. autograd.record()
Here a Python language feature, the context manager, is used; the record() function simply executes:
return _RecordingStateScope(True, train_mode)
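For context, the enclosing definition in mxnet/autograd.py is essentially a one-line wrapper around this return statement; roughly (a sketch with the docstring omitted):

def record(train_mode=True):
    # Returns a scope context to be used in a 'with' statement; computation
    # inside the scope is recorded so that gradients can be computed later
    return _RecordingStateScope(True, train_mode)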
2. _RecordingStateScope()
class _RecordingStateScope(object):
    """Scope for managing training state.

    Example::

        with _RecordingStateScope(True, True):
            y = model(x)
            backward([y])
    """
    def __init__(self, is_record, train_mode):  #pylint: disable=redefined-outer-name
        self._enter_is_record = is_record
        self._enter_train_mode = train_mode
        self._prev_is_record = None
        self._prev_train_mode = None

    def __enter__(self):
        if self._enter_is_record is not None:
            self._prev_is_record = set_recording(self._enter_is_record)
        if self._enter_train_mode is not None:
            self._prev_train_mode = set_training(self._enter_train_mode)

    def __exit__(self, ptype, value, trace):
        if self._enter_is_record is not None and self._prev_is_record != self._enter_is_record:
            set_recording(self._prev_is_record)
        if self._enter_train_mode is not None and self._prev_train_mode != self._enter_train_mode:
            set_training(self._prev_train_mode)
Here __enter__ and __exit__ are the methods for entering and exiting the context manager. During training we first enter __enter__; self._enter_is_record is True by default, so set_recording(True) is called.
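Seen from the Python side, this save-and-restore behavior of __enter__/__exit__ can be observed with the public helpers autograd.is_recording() and autograd.pause(); a minimal sketch:

from mxnet import autograd

print(autograd.is_recording())          # False before entering any scope
with autograd.record():                 # __enter__ calls set_recording(True)
    print(autograd.is_recording())      # True inside the scope
    with autograd.pause():              # the nested scope saves the outer state
        print(autograd.is_recording())  # False while paused
    print(autograd.is_recording())      # True again: the inner __exit__ restored it
print(autograd.is_recording())          # False: the outer __exit__ restored the original state

The set_recording function that __enter__ calls is defined as follows: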
def set_recording(is_recording): #pylint: disable=redefined-outer-name
    """Set status to recording/not recording. When recording, graph will be constructed
    for gradient computation.

    Parameters
    ----------
    is_recording: bool

    Returns
    -------
    previous state before this set.
    """
    prev = ctypes.c_int()
    check_call(_LIB.MXAutogradSetIsRecording(
        ctypes.c_int(is_recording), ctypes.byref(prev)))
    return bool(prev.value)
From the above, the call ultimately reaches the MXAutogradSetIsRecording function inside libmxnet.dll (the dynamic-link library produced by compiling the source).
int MXAutogradSetIsRecording(int is_recording, int* prev) {
  API_BEGIN();
  *prev = Imperative::Get()->set_is_recording(static_cast<bool>(is_recording));
  API_END();
}
bool set_is_recording(bool is_recording) {
  bool old = is_recording_;
  is_recording_ = is_recording;
  return old;
}
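The previous state that this C++ function writes into prev is exactly what the Python-level set_recording returns; a small sketch of using that return value directly:

from mxnet import autograd

prev = autograd.set_recording(True)   # flips the flag, returns the previous state
print(prev)                           # False if nothing was recording before
autograd.set_recording(prev)          # restore, mirroring _RecordingStateScope.__exit__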
We can see that after record() is used, what ultimately changes is the value of is_recording_, which is defined as follows:
namespace mxnet {
/*! \brief runtime functions for NDArray */
class Imperative {
 public:
  /*! \brief */
  class AGInfo {
   public:
    Context ctx;
    OpReqType grad_req;
    OpStatePtr state;
    std::vector<NDArray> outputs;
    std::vector<NDArray> out_grads;
    bool fresh_out_grad;
    ...
  };
  ...
  std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
                                 const std::vector<NDArray*>& ograds,
                                 const std::vector<NDArray*>& variables,
                                 bool is_train, bool retain_graph,
                                 bool create_graph);
  /*! \return AutogradRuntime singleton */
  static Imperative* Get();

 private:
  friend class NDArray;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local bool is_train_;
  static thread_local bool is_recording_;
  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
  // Delete it in the next major release.
  static thread_local bool is_np_shape_;
#else
  static MX_THREAD_LOCAL bool is_train_;
  static MX_THREAD_LOCAL bool is_recording_;
  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
  // Delete it in the next major release.
  static MX_THREAD_LOCAL bool is_np_shape_;
#endif
};
is_recording_ is declared static thread_local, i.e. with the two qualifiers static and thread_local. As a static class member, changing the value of is_recording_ changes it for every object of the class; as a thread_local variable, its lifetime is that of its thread. A static thread_local variable is therefore a per-thread static: within one thread the variable is shared by all instances, but the copies held by different threads are isolated from one another.
Experiments also show that the computation inside the autograd.record() context manager is multi-threaded, and that calling record() does modify the values of is_recording_ and is_train_.
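The thread-local part can be checked with a small sketch that uses only the public autograd API and Python's threading module: entering record() in the main thread does not change the flag seen by a newly started worker thread.

import threading

from mxnet import autograd

def worker():
    # is_recording_ is thread_local, so this thread sees its own copy
    print("worker:", autograd.is_recording())   # False

with autograd.record():
    print("main:", autograd.is_recording())     # True in the thread that entered record()
    t = threading.Thread(target=worker)
    t.start()
    t.join()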
So how does MXNet use autograd.record() to control gradient computation during backpropagation?
The code we use for gradient computation during training looks like this:
a = nd.random.normal(shape=1)
a.attach_grad()
with autograd.record():
    c = f(a)
c.backward()
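Here f stands for an arbitrary differentiable function. As a hypothetical concrete choice, taking f(a) = 2 * a ** 2 gives an expected gradient of 4 * a, which can be checked after backward():

from mxnet import nd, autograd

def f(a):                  # hypothetical example function
    return 2 * a ** 2

a = nd.random.normal(shape=1)
a.attach_grad()            # allocate memory to store the gradient of a
with autograd.record():
    c = f(a)               # the forward computation is recorded here
c.backward()
print(a.grad == 4 * a)     # expected: [1.]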
Gradient computation is carried out by calling backward() on an NDArray, and from the Imperative class above we can see that NDArray is a friend class of Imperative (Imperative is likewise declared a friend of NDArray; mutual friend classes can call each other's methods and access each other's members), and the Backward() used for automatic differentiation is also a method of Imperative. Therefore, when we call autograd.record(), the corresponding is_recording_ and is_train_ members of the Imperative class are set. For the NDArray objects created inside the with context-manager scope, because is_recording_ and is_train_ are static members of the friend class, setting is_recording_ to 1 means that all NDArray variables produced inside the with scope are marked as recorded, and these marked variables then take part in backpropagation and gradient computation. Variables outside the scope do not take part; we can also use with autograd.pause() to keep certain computations out of the backward gradient computation, as in the sketch below.
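As a sketch of that last point, a computation wrapped in autograd.pause() inside a record scope is treated as a constant and contributes no gradient (scale below is a hypothetical intermediate value):

from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 3.0])
x.attach_grad()
with autograd.record():
    y = x * x                  # recorded: part of the computation graph
    with autograd.pause():
        scale = y.sum()        # computed, but not recorded
    z = y * scale
z.backward()
print(x.grad)                  # 2 * x * scale, with scale treated as a constant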