Pytorch 张量基础知识
1. 张量简介
- 张量:存储同一类型元素的容器。
- 创建方式:
torch.tensor():从多种数据结构创建,支持数值和布尔类型。torch.Tensor():根据形状创建张量,默认数据类型为float32。
2. 创建张量
2.1 创建特定值张量
torch.ones(),torch.ones_like():创建全1张量。torch.zeros(),torch.zeros_like():创建全0张量。torch.full(),torch.full_like():创建指定值张量。
2.2 创建线性和随机张量
torch.arange(start, end, step):创建等差数列。torch.linspace(start, end, steps):创建等间隔数列。torch.rand():创建均匀分布随机数。torch.randn():创建标准正态分布随机数。torch.randint():创建指定范围内的随机整数。
2.3 设置随机种子
torch.random.initial_seed():设置随机数种子。torch.manual_seed():设置随机数种子(常用)。
3. 数据类型转换
- 使用
.type(torch.int16)等方法进行数据类型转换。
4. 张量与NumPy数组互转
4.1 张量 -> NumPy数组
- 共享内存:
.numpy() - 不共享内存:
.numpy.copy()或.ndarray()
4.2 NumPy数组 -> 张量
- 不共享内存:
torch.tensor(ndarray) - 共享内存:
torch.from_numpy(ndarray)
4.3 标量张量 -> 数值
.item()方法。
5. 张量运算
5.1 点乘与矩阵乘法
- 点乘:
- 要求:形状相同。
- 使用:
torch.mul()或*运算符。
- 矩阵乘法:
- 要求:第一个张量的最后一维大小等于第二个张量的倒数第二维大小。
- 使用:
torch.matmul()或@运算符。
5.2 常用运算函数
sum(),mean(),max(),min():均支持dim参数。dim=0:按列运算。dim=1:按行运算。- 默认:对所有元素运算。
pow()/**,sqrt(),exp(),log(),log10(),log2()。
5.3 广播机制
当两个张量形状不同时,可通过广播机制进行运算。
规则:
- 若维度数量不同,小维度张量在最左边补1,直到维度数量相同。
- 若某维度大小不同,则将该维度较小的张量扩展为较大张量的大小(复制元素)。
- 若某维度大小不同且均不为1,则报错。
7. 索引操作
分类:
- 简单行列索引
- 列表索引
- 范围索引
- 布尔索引
- 多维索引
torch.manual_seed(42)
t1 = torch.randint(1, 10, (5, 5))
print(f't1:\n{t1}')
# 场景1 : 简单行列索引
print(t1[0, 0]) # 第一行第一列元素
print(t1[:, 1]) # 第二列所有元素
print('-' * 50)
# 场景2: 列表索引
rows = [0, 2, 4]
cols = [1, 3, 4]
print(t1[rows, cols]) # 取出 (0,1), (2,3), (4,4) 元素
print('-' * 50)
# 场景3: 范围索引
# 左闭右开
print(t1[1:4, 2:5]) # 取出第2到4行, 第3到5列的子矩阵
print(t1[1::2, 0::2]) # 取出奇数行,偶数列, 步长为2
print('-' * 50)
# 场景4: 布尔索引
mask = t1 > 5
print(f'mask:\n{mask}')
print(t1[mask]) # 取出所有大于5的元素
print(t1[:, 2][t1[:, 2] > 5]) # 取出第3列大于5的元素
print(t1[t1[:, 2] > 5]) # 取出第3列大于5的行数据
print('-' * 50)
# 场景5: 多维索引
t2 = torch.randint(1, 10, (3, 4, 5)) # 3个4x5矩阵
print(f't2:\n{t2}')
print(t2[0, :, :]) # 取出第1个矩阵
print(t2[:, 1, :]) # 取出所有矩阵的第2行
print(t2[:, :, 2]) # 取出所有矩阵的第3列
print(t2[1, 2, 3]) # 取出第2个矩阵的第3行第4列元素
print('-' * 50)
8. 形状操作
| 方法 | 说明 | 注意 |
|---|---|---|
reshape() / view() |
改变形状,返回新张量,原张量不变。 | view() 要求张量连续存储(经transpose或permute后可能不连续,需先调用.contiguous())。 |
squeeze() |
去除所有大小为1的维度。 | - |
unsqueeze() |
在指定位置插入大小为1的维度。 | - |
permute() |
维度置换,改变维度顺序。 | - |
transpose() |
交换指定的两个维度。 | - |
flatten() |
将多维张量展平为一维。 | - |
contiguous() |
将非连续存储的张量变为连续存储。 | 可用 .is_contiguous() 判断。 |
9. 张量的拼接与拆分
9.1 拼接
torch.cat():按指定维度拼接。所有张量在非拼接维度形状必须相同。torch.stack():在指定维度堆叠,新增一个维度。所有张量形状必须相同。
9.2 拆分
torch.split():按指定大小拆分。torch.chunk():按指定块数拆分。

浙公网安备 33010602011771号