【PyTorch】命令汇总表
| 所属操作 | 函数名 | 用处 | 使用频率 |
|---|---|---|---|
| 张量的创建 | torch.tensor() |
从列表 / 数组创建张量 | ★★★★★ |
torch.Tensor() |
创建默认 float32 类型张量 | ★★★☆☆ | |
torch.zeros() |
创建全 0 张量 | ★★★★★ | |
torch.ones() |
创建全 1 张量 | ★★★★★ | |
torch.eye() |
创建单位矩阵 | ★★★☆☆ | |
torch.arange() |
创建等差序列张量 | ★★★★☆ | |
torch.linspace() |
创建等间距序列张量 | ★★★☆☆ | |
torch.rand() |
创建 [0,1) 均匀分布随机张量 | ★★★★★ | |
torch.randn() |
创建标准正态分布随机张量 | ★★★★★ | |
torch.randint() |
创建指定范围随机整数张量 | ★★★☆☆ | |
| 张量的类型转换 | tensor.float() |
转换为 float32 类型 | ★★★★★ |
tensor.double() |
转换为 float64 类型 | ★★★☆☆ | |
tensor.long() |
转换为 int64 类型 | ★★★★★ | |
tensor.int() |
转换为 int32 类型 | ★★★☆☆ | |
tensor.bool() |
转换为 bool 类型 | ★★★☆☆ | |
torch.from_numpy() |
从 NumPy 数组创建张量 | ★★★★☆ | |
tensor.numpy() |
张量转换为 NumPy 数组 | ★★★★☆ | |
tensor.cuda() |
张量迁移到 GPU | ★★★★★ | |
tensor.cpu() |
张量迁移到 CPU | ★★★★★ | |
tensor.to() |
通用设备 / 类型转换 | ★★★★★ | |
| 张量数值计算 | torch.add()/+ |
张量元素级加法 | ★★★★★ |
torch.sub()/- |
张量元素级减法 | ★★★★★ | |
torch.mul()/* |
张量元素级乘法 | ★★★★★ | |
torch.div()// |
张量元素级除法 | ★★★★★ | |
torch.matmul()/@ |
张量矩阵乘法 | ★★★★★ | |
torch.sum() |
张量元素求和 | ★★★★★ | |
torch.mean() |
张量元素求均值 | ★★★★★ | |
torch.max() |
张量元素求最大值 | ★★★★★ | |
torch.min() |
张量元素求最小值 | ★★★★★ | |
torch.argmax() |
张量最大值索引 | ★★★★☆ | |
torch.argmin() |
张量最小值索引 | ★★★★☆ | |
torch.std() |
张量元素求标准差 | ★★★☆☆ | |
torch.var() |
张量元素求方差 | ★★★☆☆ | |
| 张量运算函数 | torch.abs() |
张量元素取绝对值 | ★★★★☆ |
torch.sqrt() |
张量元素开平方 | ★★★☆☆ | |
torch.exp() |
张量元素指数运算 | ★★★☆☆ | |
torch.pow() |
张量元素幂运算 | ★★★★★ | |
torch.log() |
张量元素对数运算 | ★★★☆☆ | |
torch.sin() |
张量元素正弦运算 | ★★★☆☆ | |
torch.cos() |
张量元素余弦运算 | ★★★☆☆ | |
torch.tanh() |
张量元素双曲正切运算 | ★★★★☆ | |
torch.gt()/> |
张量元素大于比较 | ★★★★☆ | |
torch.lt()/< |
张量元素小于比较 | ★★★★☆ | |
torch.eq()/== |
张量元素相等比较 | ★★★★☆ | |
torch.where() |
按条件选择张量元素 | ★★★★☆ | |
| 张量索引操作 | tensor[index] |
基础索引(取行 / 列 / 元素) | ★★★★★ |
tensor[start:end] |
切片索引 | ★★★★★ | |
tensor[:, index] |
列索引 | ★★★★★ | |
tensor[mask] |
布尔索引 | ★★★★☆ | |
torch.index_select() |
按指定索引选择张量元素 | ★★★☆☆ | |
torch.masked_select() |
按布尔掩码选择张量元素 | ★★★☆☆ | |
torch.gather() |
按索引从张量中收集元素 | ★★★☆☆ | |
| 张量形状操作 | tensor.shape/tensor.size() |
获取张量形状 | ★★★★★ |
tensor.reshape() |
改变张量形状 | ★★★★★ | |
tensor.view() |
改变张量形状(内存连续) | ★★★★★ | |
tensor.flatten() |
展平张量为一维 | ★★★★☆ | |
tensor.squeeze() |
移除维度为 1 的轴 | ★★★★☆ | |
tensor.unsqueeze() |
增加维度为 1 的轴 | ★★★★★ | |
tensor.permute() |
重排张量维度 | ★★★★☆ | |
tensor.transpose() |
交换张量两个维度 | ★★★★★ | |
| 张量拼接操作 | torch.cat() |
按指定维度拼接张量(不新增维度) | ★★★★★ |
torch.stack() |
堆叠张量(新增维度) | ★★★★☆ | |
| 自动微分模块 | tensor.requires_grad_() |
开启张量梯度追踪 | ★★★★★ |
tensor.backward() |
反向传播计算梯度 | ★★★★★ | |
tensor.grad |
获取张量梯度 | ★★★★★ | |
torch.no_grad() |
关闭梯度追踪上下文 | ★★★★★ | |
tensor.grad.zero_() |
清零张量梯度 | ★★★★★ |

浙公网安备 33010602011771号