A usage tip for hybrid_forward in MXNet

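from mxnet.gluon import nn
from mxnet import nd

class SliceLike(nn.HybridBlock):
    def __init__(self, xs, **kwargs):
        super().__init__(**kwargs)
        # Assigning a Parameter to an attribute registers it in self._reg_params under the key 'xs'
        self.xs = self.params.get_constant('x_', xs)
        # A trainable Parameter, registered under the key 'ys'
        self.ys = self.params.get('y_', shape=xs.shape)
        # A plain attribute: not a Parameter, so it is not registered
        self.A = 'sl'

    def hybrid_forward(self, F, x, xs, ys):
        # xs and ys arrive as keyword arguments built from self._reg_params
        print(self._reg_params)
        # Slice xs along axis 1 to x's size on that axis; x * 0 only supplies the shape
        a = F.slice_like(xs, x * 0, axes=(1,))
        return a.reshape((1, -1, 4))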

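The arguments of hybrid_forward take the form (self, F, x, *args, **kwargs), where F is the operator namespace (mxnet.ndarray before hybridization, mxnet.symbol after) and x is the input.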
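To make this calling convention concrete, here is a simplified pure-Python mimic, assuming Gluon behaves as this post describes: the block records Parameter attributes in a dict and, when called, passes F, the input and those parameters (as keyword arguments) to hybrid_forward. MiniParam, MiniBlock, Demo and the placeholder string "nd" standing in for F are all hypothetical; this is not MXNet's actual code.

```python
# Simplified, hypothetical mimic of how a Gluon HybridBlock feeds its registered
# parameters to hybrid_forward. None of these classes are MXNet classes.

class MiniParam:
    """Stand-in for gluon.Parameter: a name plus a stored value."""
    def __init__(self, name, value):
        self.name = name
        self.value = value

    def data(self):
        return self.value


class MiniBlock:
    def __init__(self):
        self._reg_params = {}  # not a MiniParam, so __setattr__ just stores it

    def __setattr__(self, attr, value):
        # Parameters are recorded under the attribute name, not under value.name
        if isinstance(value, MiniParam):
            self._reg_params[attr] = value
        object.__setattr__(self, attr, value)

    def __call__(self, x, *args):
        # Every registered parameter becomes a keyword argument of hybrid_forward
        params = {attr: p.data() for attr, p in self._reg_params.items()}
        return self.hybrid_forward("nd", x, *args, **params)


class Demo(MiniBlock):
    def __init__(self):
        super().__init__()
        self.xs = MiniParam("x_", [1, 2, 3])
        self.ys = MiniParam("y_", [4, 5, 6])
        self.A = "sl"  # plain attribute: not registered

    def hybrid_forward(self, F, x, xs, ys):
        # xs and ys are received by keyword, matching the attribute names
        return sorted(self._reg_params), x + sum(xs) + sum(ys)


print(Demo()(10))  # (['xs', 'ys'], 31)
```

Renaming ys to, say, zs in the hybrid_forward signature would make the call fail with a TypeError for the unexpected keyword argument ys, which is why the argument names have to match the attribute names.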

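Here is what (self, F, x, xs, ys) means. self._reg_params collects the parameters created with self.params.get_constant or self.params.get into a dict keyed by the attribute names they were assigned to (xs and ys, not the parameter names 'x_' and 'y_'), and each entry is passed directly into hybrid_forward as a keyword argument. The argument names in hybrid_forward must therefore match those attribute names: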

# 8 x 10 matrix: slicing axis 1 down to 1 column leaves 8 values, which reshape to (1, 2, 4)
xs = nd.arange(80).reshape((8, 10))
sx = SliceLike(xs)
sx.initialize()
y = nd.zeros((1, 1, 2, 3))
sx(y)
{'xs': Constant slicelike12_x_ (shape=(8, 10), dtype=<class 'numpy.float32'>), 'ys': Parameter slicelike12_y_ (shape=(8, 10), dtype=<class 'numpy.float32'>)}

[[[ 0. 10. 20. 30.]
  [40. 50. 60. 70.]]]
<NDArray 1x2x4 @cpu(0)>
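The numbers in this result can be checked by hand: sliced along axis 1, slice_like keeps xs only up to x's size on that axis, which is 1 for y = nd.zeros((1, 1, 2, 3)), so only the first column survives; reshape((1, -1, 4)) then regroups those values into rows of 4. The pure-Python sketch below reproduces that arithmetic for an 8 x 10 arange input (8 column values, hence a 1x2x4 result). arange_2d, slice_like_2d and reshape_1_by are hypothetical helpers limited to 2-D data, not MXNet functions.

```python
def arange_2d(rows, cols):
    # Hypothetical stand-in for nd.arange(rows * cols).reshape((rows, cols))
    return [[float(r * cols + c) for c in range(cols)] for r in range(rows)]

def slice_like_2d(data, like_shape, axes):
    # Hypothetical 2-D slice_like: truncate each listed axis to like_shape on that axis
    n_rows = like_shape[0] if 0 in axes else len(data)
    n_cols = like_shape[1] if 1 in axes else len(data[0])
    return [row[:n_cols] for row in data[:n_rows]]

def reshape_1_by(flat, width):
    # Equivalent of reshape((1, -1, width)) for a flat list; length must divide evenly
    if len(flat) % width:
        raise ValueError("cannot reshape %d values into rows of %d" % (len(flat), width))
    return [[flat[i:i + width] for i in range(0, len(flat), width)]]

xs = arange_2d(8, 10)
x_shape = (1, 1, 2, 3)                          # shape of y = nd.zeros((1, 1, 2, 3))
sliced = slice_like_2d(xs, x_shape, axes=(1,))  # 8 x 1: the first column
flat = [v for row in sliced for v in row]
print(reshape_1_by(flat, 4))
# [[[0.0, 10.0, 20.0, 30.0], [40.0, 50.0, 60.0, 70.0]]]
```

Run on a 10 x 10 input instead, the same steps leave 10 values, which cannot be split evenly into rows of 4, so reshape_1_by raises ValueError.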

posted @ 2019-03-27 22:37  xinet