MXNet ConvLSTM 使用方法

1. ConvLSTMCell 源代码（符号式接口 mx.rnn.ConvLSTMCell）

class ConvLSTMCell(BaseConvRNNCell):
    """Convolutional LSTM network cell (symbolic API).

    Reference:
        Xingjian et al. NIPS2015

    Parameters
    ----------
    input_shape : tuple of int
        Shape of the input at a single time step.
    num_hidden : int
        Number of units in the output symbol.
    h2h_kernel : tuple of int, default (3, 3)
        Kernel of the Convolution operator in state-to-state transitions.
    h2h_dilate : tuple of int, default (1, 1)
        Dilation of the Convolution operator in state-to-state transitions.
    i2h_kernel : tuple of int, default (3, 3)
        Kernel of the Convolution operator in input-to-state transitions.
    i2h_stride : tuple of int, default (1, 1)
        Stride of the Convolution operator in input-to-state transitions.
    i2h_pad : tuple of int, default (1, 1)
        Padding of the Convolution operator in input-to-state transitions.
    i2h_dilate : tuple of int, default (1, 1)
        Dilation of the Convolution operator in input-to-state transitions.
    i2h_weight_initializer : str or Initializer
        Initializer for the input-to-state convolution weights.
    h2h_weight_initializer : str or Initializer
        Initializer for the state-to-state convolution weights.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the input-to-state bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the state-to-state bias vector.
    activation : str or Symbol,
        default functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2)
        Activation applied to the candidate cell input and to the new cell
        state before the output gate.
    prefix : str, default 'ConvLSTM_'
        Prefix for names of layers (and of weights if params is None).
    params : RNNParams, default None
        Container for weight sharing between cells. Created if None.
    conv_layout : str, default 'NCHW'
        Layout of the Convolution operator; gates are split along its 'C' axis.
    """
    def __init__(self, input_shape, num_hidden,
                 h2h_kernel=(3, 3), h2h_dilate=(1, 1),
                 i2h_kernel=(3, 3), i2h_stride=(1, 1),
                 i2h_pad=(1, 1), i2h_dilate=(1, 1),
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 activation=functools.partial(symbol.LeakyReLU, act_type='leaky', slope=0.2),
                 prefix='ConvLSTM_', params=None,
                 conv_layout='NCHW'):
        # Every setting is forwarded unchanged to BaseConvRNNCell.
        super(ConvLSTMCell, self).__init__(
            input_shape=input_shape,
            num_hidden=num_hidden,
            prefix=prefix,
            params=params,
            conv_layout=conv_layout,
            activation=activation,
            i2h_kernel=i2h_kernel,
            i2h_stride=i2h_stride,
            i2h_pad=i2h_pad,
            i2h_dilate=i2h_dilate,
            h2h_kernel=h2h_kernel,
            h2h_dilate=h2h_dilate,
            i2h_weight_initializer=i2h_weight_initializer,
            h2h_weight_initializer=h2h_weight_initializer,
            i2h_bias_initializer=i2h_bias_initializer,
            h2h_bias_initializer=h2h_bias_initializer,
        )

    @property
    def _gate_names(self):
        """Name suffixes of the four gates: input, forget, candidate, output."""
        return ['_i', '_f', '_c', '_o']

    def __call__(self, inputs, states):
        """Advance the cell by one time step.

        Parameters
        ----------
        inputs : Symbol
            Input for this time step.
        states : list of Symbol
            ``[h, c]`` from the previous step; only ``states[1]`` (the cell
            state) is read directly here, both are passed to ``_conv_forward``.

        Returns
        -------
        (Symbol, list of Symbol)
            The new hidden state, and ``[new_hidden, new_cell]``.
        """
        self._counter += 1
        step = '%st%d_' % (self._prefix, self._counter)

        i2h, h2h = self._conv_forward(inputs, states, step)
        # NOTE: symbols are created in exactly the original order, because
        # unnamed symbols (the sums/products below) get auto-generated names.
        pre_activation = i2h + h2h
        channel_axis = self._conv_layout.find('C')
        parts = symbol.SliceChannel(pre_activation, num_outputs=4, axis=channel_axis,
                                    name="%sslice" % step)

        in_gate = symbol.Activation(parts[0], act_type="sigmoid", name='%si' % step)
        forget_gate = symbol.Activation(parts[1], act_type="sigmoid", name='%sf' % step)
        candidate = self._get_activation(parts[2], self._activation, name='%sc' % step)
        out_gate = symbol.Activation(parts[3], act_type="sigmoid", name='%so' % step)

        prev_cell = states[1]
        # c_t = f * c_{t-1} + i * g
        new_cell = symbol._internal._plus(forget_gate * prev_cell, in_gate * candidate,
                                          name='%sstate' % step)
        # h_t = o * act(c_t)
        squashed_cell = self._get_activation(new_cell, self._activation)
        new_hidden = symbol._internal._mul(out_gate, squashed_cell, name='%sout' % step)

        return new_hidden, [new_hidden, new_cell]

    @property
    def state_info(self):
        """Shape/layout descriptors for the two recurrent states ``[h, c]``."""
        return [{'shape': self._state_shape, '__layout__': self._conv_layout}
                for _ in range(2)]

2. Conv2DRNNCell 使用方法

import mxnet as mx
from mxnet import gluon, nd, autograd
# Demo: run one step of a single-layer Conv2DRNNCell on a random batch.
nbatch = 10
nfilters = 12
height, width = 64, 64
shape = (nbatch, nfilters, height, width)
# nd.random.uniform is the current API; nd.random_uniform is a legacy alias.
xx = nd.random.uniform(shape=shape)
# Single layer. Conv2DRNNCell's input_shape is the per-sample shape (C, H, W),
# i.e. the batch shape without its leading batch dimension.
net = gluon.contrib.rnn.Conv2DRNNCell(input_shape=shape[1:],
                                      hidden_channels=nfilters,
                                      i2h_kernel=(3, 3),
                                      h2h_kernel=(3, 3))
init_state = net.begin_state(batch_size=nbatch)
state = init_state
with autograd.record():
    out, state = net(xx, state)
print(out.shape)
# output: (10, 12, 62, 62) -- i2h_pad is left at its default of (0, 0), so the
# 3x3 input-to-hidden convolution shrinks each spatial side by 2.
print(state[0].shape)
# output: (10, 12, 62, 62)

3. Conv2DLSTMCell 使用方法

class BottleNeckLSTM(gluon.nn.HybridBlock):
    """Unroll a Conv2DLSTMCell over axis 2 of a 5-D input.

    The per-step outputs are concatenated along the channel axis (dim 1).

    ``x`` is sliced along axis 2, so it is expected to be laid out as
    (batch, channels, time, height, width). Each slice is reshaped to
    ``(-1,) + input_shape`` before being fed to the cell. With the 3x3
    kernels and ``i2h_pad=(1, 1)`` used here, each step output has
    ``hidden_channels`` channels and the input's spatial size. The result
    therefore has ``seq_len * hidden_channels`` channels.

    Parameters
    ----------
    input_shape : tuple of int
        Per-time-step input shape (C, H, W) given to Conv2DLSTMCell.
    hidden_channels : int
        Number of output channels of the LSTM cell.
    seq_len : int, default 64
        Number of time steps to unroll (size of axis 2 of the input).
    """

    def __init__(self, input_shape, hidden_channels, seq_len=64):
        super(BottleNeckLSTM, self).__init__()
        self.input_shape = tuple(input_shape)
        self.seq_len = seq_len
        # The cell must exist before anything uses it. The original called
        # self.net.begin_state() before assigning self.net, which raised
        # AttributeError. name_scope keeps parameter names unique per instance.
        with self.name_scope():
            self.net = gluon.contrib.rnn.Conv2DLSTMCell(input_shape=input_shape,
                                                        hidden_channels=hidden_channels,
                                                        i2h_pad=(1, 1),
                                                        i2h_kernel=(3, 3),
                                                        h2h_kernel=(3, 3))

    def _initial_state(self, F, x):
        """Build zero [h, c] states matching x's batch size (and context)."""
        if isinstance(x, mx.nd.NDArray):
            return self.net.begin_state(batch_size=x.shape[0], func=F.zeros,
                                        ctx=x.context)
        # Symbolic (hybridized) mode: batch size 0 means "unknown"; it is
        # filled in by shape inference, as Gluon's own unroll does.
        return self.net.begin_state(batch_size=0, func=F.zeros)

    def hybrid_forward(self, F, x, *args, **kwargs):
        # State and outputs are local to each call. Keeping them on self made
        # every call reuse the previous batch's state and concatenate stale
        # outputs from earlier calls.
        state = self._initial_state(F, x)
        step_shape = (-1,) + self.input_shape
        outputs = []
        for t in range(self.seq_len):
            step = F.slice_axis(x, axis=2, begin=t, end=t + 1)
            step = F.reshape(step, shape=step_shape)
            step_out, state = self.net(step, state)
            outputs.append(step_out)
        return F.concat(*outputs, dim=1)
posted @ 2021-12-15 14:19  while(1){happiness;}  阅读(112)  评论(0)    收藏  举报