mxnet常用数据API记录(三)

 固定部分权重,不更新

比如说只更新最后一个dense层的权重:Trainer里面用下面这个而不是所有的model.collect_params()

model.collect_params(select='.*dense0')

如果想设置为只有dense0层不更新

model.collect_params('.*dense0').setattr('grad_req','null')

 

固定b_net的权重不更新:
方式一:
with autograd.record():
    a = a_net(batch_s)
    with autograd.pause(train_mode=True):
        b= b_net(batch_s_)
    loss = a - b
loss.backward()


方式二:
for _, w in b_net.collect_params().items():
    w.grad_req = 'null'

 

posted @ 2024-09-21 13:15  silence_cho  阅读(9)  评论(0)    收藏  举报