Defining parameters in PyTorch with nn.Parameter

import torch
import torch.nn as nn


class LayerNorm(nn.Module):  # layer normalization
    """Construct a layernorm module (See citation for details)."""

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        # Learnable scale (a_2) and shift (b_2), registered as module parameters
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps  # added to std to avoid division by zero

    def forward(self, x):
        # Normalize over the last dimension
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
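
What wrapping a tensor in nn.Parameter buys you can be seen in a small sketch. The Scale module and its weight and offset attributes are hypothetical names made up for this illustration. The same behavior applies to a_2 and b_2 above: only tensors wrapped in nn.Parameter are registered with the module, returned by parameters(), and saved in state_dict().

```python
import torch
import torch.nn as nn


class Scale(nn.Module):
    """Hypothetical toy module used only for this illustration."""

    def __init__(self, features):
        super().__init__()
        # Wrapped in nn.Parameter: registered with the module and trainable.
        self.weight = nn.Parameter(torch.ones(features))
        # Plain tensor attribute: not registered, so optimizers never see it.
        self.offset = torch.zeros(features)

    def forward(self, x):
        return self.weight * x + self.offset


module = Scale(4)
param_names = [name for name, _ in module.named_parameters()]
print(param_names)                       # ['weight']
print(module.weight.requires_grad)       # True
print("offset" in module.state_dict())   # False
```

An optimizer built from module.parameters() would therefore update weight but never offset.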

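Initialization (self.W here stands for any nn.Parameter defined in __init__):

nn.init.uniform_(self.W, 0, 1)  # fill W in place with samples from U(0, 1)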
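
Since the post does not show where self.W comes from, here is a minimal sketch of the initialization call in context. The Linear2D class and its shape arguments are assumptions for illustration. Only the name W and the nn.init.uniform_ call are taken from the line above.

```python
import torch
import torch.nn as nn


class Linear2D(nn.Module):
    """Hypothetical module; the name W mirrors the snippet above."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # torch.empty allocates uninitialized memory, so W must be filled before use.
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        # In-place initialization: draw every entry from the uniform distribution U(0, 1).
        nn.init.uniform_(self.W, 0, 1)

    def forward(self, x):
        return x @ self.W


torch.manual_seed(0)  # fixed seed so repeated runs give the same weights
layer = Linear2D(3, 2)
in_range = bool(((layer.W >= 0) & (layer.W <= 1)).all())
print(layer.W.shape)
print(in_range)
```

The trailing underscore in uniform_ marks an in-place operation: the parameter object stays the same, only its values change, so it remains registered with the module.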
posted @ 2021-01-27 16:49  douzujun