理解Pytorch的dim

# Time    : 2022.07.06 上午 10:33
# Author  : Vandaci(cnfendaki@qq.com)
# File    : learning_tensor_dim.py
# Project : LearningPytorch
import torch
import torch.nn as nn

if __name__ == '__main__':
    a = torch.tensor([[1., 2.], [3., 4.]])
    # dim=2 shape=[2,2] (row,col) dim就是沿着row和col的方向,
    # 沿着row的方向就是每列,沿着col的方向就是每行
    smax = nn.Softmax(dim=0)
    o = smax(a)
    print(o)
    '''输出结果
    tensor([[0.1192, 0.1192],
        [0.8808, 0.8808]])
    '''
    smax = nn.Softmax(dim=1)
    o = smax(a)
    print(o)
    '''输出结果
    tensor([[0.2689, 0.7311],
            [0.2689, 0.7311]])
    '''
    # 对于dim>=3,亦可照此推断
    b = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8.]]])
    smax = nn.Softmax(dim=0)  # 沿着batch的方向:1对应5
    o = smax(b)
    print(o)
    '''输出结果
    tensor([[[0.0180, 0.0180],
         [0.0180, 0.0180]],

        [[0.9820, 0.9820],
         [0.9820, 0.9820]]])
    '''
    pass
posted @ 2022-07-06 10:42  Vandaci  阅读(73)  评论(0)    收藏  举报