Pytorch 张量基础知识

1. 张量简介

  • 张量:存储同一类型元素的容器。
  • 创建方式
    • torch.tensor():从多种数据结构创建,支持数值和布尔类型。
    • torch.Tensor():根据形状创建张量,默认数据类型为 float32

2. 创建张量

2.1 创建特定值张量

  • torch.ones(), torch.ones_like():创建全1张量。
  • torch.zeros(), torch.zeros_like():创建全0张量。
  • torch.full(), torch.full_like():创建指定值张量。

2.2 创建线性和随机张量

  • torch.arange(start, end, step):创建等差数列。
  • torch.linspace(start, end, steps):创建等间隔数列。
  • torch.rand():创建均匀分布随机数。
  • torch.randn():创建标准正态分布随机数。
  • torch.randint():创建指定范围内的随机整数。

2.3 设置随机种子

  • torch.random.initial_seed():设置随机数种子。
  • torch.manual_seed():设置随机数种子(常用)。

3. 数据类型转换

  • 使用 .type(torch.int16) 等方法进行数据类型转换。

4. 张量与NumPy数组互转

4.1 张量 -> NumPy数组

  • 共享内存.numpy()
  • 不共享内存.numpy.copy().ndarray()

4.2 NumPy数组 -> 张量

  • 不共享内存torch.tensor(ndarray)
  • 共享内存torch.from_numpy(ndarray)

4.3 标量张量 -> 数值

  • .item() 方法。

5. 张量运算

5.1 点乘与矩阵乘法

  • 点乘
    • 要求:形状相同。
    • 使用:torch.mul()* 运算符。
  • 矩阵乘法
    • 要求:第一个张量的最后一维大小等于第二个张量的倒数第二维大小。
    • 使用:torch.matmul()@ 运算符。

5.2 常用运算函数

  • sum(), mean(), max(), min():均支持 dim 参数。
    • dim=0:按列运算。
    • dim=1:按行运算。
    • 默认:对所有元素运算。
  • pow() / **, sqrt(), exp(), log(), log10(), log2()

5.3 广播机制

当两个张量形状不同时,可通过广播机制进行运算。

规则

  1. 若维度数量不同,小维度张量在最左边补1,直到维度数量相同。
  2. 若某维度大小不同,则将该维度较小的张量扩展为较大张量的大小(复制元素)。
  3. 若某维度大小不同且均不为1,则报错

7. 索引操作

分类:

  1. 简单行列索引
  2. 列表索引
  3. 范围索引
  4. 布尔索引
  5. 多维索引
    torch.manual_seed(42)
    t1 = torch.randint(1, 10, (5, 5))
    print(f't1:\n{t1}')
    # 场景1 : 简单行列索引
    print(t1[0, 0])  # 第一行第一列元素
    print(t1[:, 1])  # 第二列所有元素
    print('-' * 50)

    # 场景2: 列表索引
    rows = [0, 2, 4]
    cols = [1, 3, 4]
    print(t1[rows, cols])  # 取出 (0,1), (2,3), (4,4) 元素
    print('-' * 50)

    # 场景3: 范围索引
    # 左闭右开
    print(t1[1:4, 2:5])  # 取出第2到4行, 第3到5列的子矩阵
    print(t1[1::2, 0::2]) # 取出奇数行,偶数列, 步长为2
    print('-' * 50)

    # 场景4: 布尔索引
    mask = t1 > 5
    print(f'mask:\n{mask}')
    print(t1[mask])  # 取出所有大于5的元素
    print(t1[:, 2][t1[:, 2] > 5]) # 取出第3列大于5的元素
    print(t1[t1[:, 2] > 5]) # 取出第3列大于5的行数据
    print('-' * 50)

    # 场景5: 多维索引
    t2 = torch.randint(1, 10, (3, 4, 5)) # 3个4x5矩阵
    print(f't2:\n{t2}')
    print(t2[0, :, :])  # 取出第1个矩阵
    print(t2[:, 1, :])  # 取出所有矩阵的第2行
    print(t2[:, :, 2])  # 取出所有矩阵的第3列
    print(t2[1, 2, 3])  # 取出第2个矩阵的第3行第4列元素
    print('-' * 50)

8. 形状操作

方法 说明 注意
reshape() / view() 改变形状,返回新张量,原张量不变。 view() 要求张量连续存储(经transposepermute后可能不连续,需先调用.contiguous())。
squeeze() 去除所有大小为1的维度。 -
unsqueeze() 在指定位置插入大小为1的维度。 -
permute() 维度置换,改变维度顺序。 -
transpose() 交换指定的两个维度。 -
flatten() 将多维张量展平为一维。 -
contiguous() 将非连续存储的张量变为连续存储。 可用 .is_contiguous() 判断。

9. 张量的拼接与拆分

9.1 拼接

  • torch.cat():按指定维度拼接。所有张量在非拼接维度形状必须相同
  • torch.stack():在指定维度堆叠,新增一个维度。所有张量形状必须相同。

9.2 拆分

  • torch.split():按指定大小拆分。
  • torch.chunk():按指定块数拆分。
posted @ 2026-01-02 18:11  xggx  阅读(2)  评论(0)    收藏  举报