Fork me on CSDN

Batch Normalization Code

def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Apply batch normalization to ``X`` and return updated running statistics.

    The mode is chosen from the global autograd state. When gradient tracking
    is disabled (e.g. inside ``torch.no_grad()``), the function runs in
    prediction mode: it normalizes with ``moving_mean``/``moving_var`` and
    returns them unchanged. Otherwise it runs in training mode: it normalizes
    with the statistics of the current mini-batch and returns updated running
    statistics.

    Args:
        X: Input tensor, 2-D or 4-D. For 2-D input, statistics are taken over
            dim 0 (one per column). For 4-D input, statistics are taken over
            dims 0, 2 and 3 (one per index of dim 1, i.e. per channel assuming
            an NCHW layout).
        gamma: Scale applied to the normalized input; must broadcast with ``X``.
        beta: Shift added after scaling; must broadcast with ``X``.
        moving_mean: Running mean, used for normalization in prediction mode.
        moving_var: Running variance, used for normalization in prediction mode.
        eps: Constant added to the variance before the square root, to avoid
            division by zero.
        momentum: Weight of the old running statistics in the update
            ``new = momentum * old + (1 - momentum) * batch_stat``.

    Returns:
        Tuple ``(Y, moving_mean, moving_var)``, where ``Y = gamma * X_hat + beta``.
        In training mode the returned running statistics are detached from the
        autograd graph.

    Raises:
        AssertionError: In training mode, if ``X`` is neither 2-D nor 4-D.
    """
    if not torch.is_grad_enabled():
        # Prediction mode: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4), "batch_norm expects a 2-D or 4-D input"
        if len(X.shape) == 2:
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # keepdim=True keeps shape (1, C, 1, 1) so the statistics
            # broadcast back against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training mode: normalize with the current batch's statistics
        # (biased variance, i.e. divided by the number of elements).
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # The running statistics are bookkeeping, not part of the differentiable
        # computation. Detach the batch statistics so the returned tensors do not
        # hold on to (and chain together) each batch's autograd graph.
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean.detach()
        moving_var = momentum * moving_var + (1.0 - momentum) * var.detach()
    # Scale and shift the normalized input.
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var

 

posted @ 2021-12-30 15:25  追风赶月的少年  阅读(49)  评论(0)  编辑  收藏  举报