在 PyTorch 的线性层(nn.Linear)中,权重矩阵的维度顺序可能与你手动计算时的直觉相反,这是导致你感到困惑的主要原因。我来解释一下这个现象背后的原理,并提供代码示例帮助你理解。
在数学中,线性变换通常表示为 y = Wx + b,其中:
- x 是输入向量(维度:
[input_size, 1])
- W 是权重矩阵(维度:
[output_size, input_size])
- y 是输出向量(维度:
[output_size, 1])
但在 PyTorch 中,nn.Linear 的权重矩阵 W 的维度是 [output_size, input_size],而输入 x 的维度通常是 [batch_size, input_size]。此时,PyTorch 会自动执行 y = xW^T + b(注意这里的 W^T 是权重矩阵的转置)。
这种设计有两个主要原因:
- 批量处理:PyTorch 需要同时处理多个样本(batch),因此输入维度是
[batch_size, input_size]。
- 内存连续性:按行存储数据(C 风格)更符合计算机内存访问模式,提高计算效率。
以下代码可以帮助你理解 nn.Linear 的权重矩阵维度:
import torch
import torch.nn as nn
如果你习惯按数学公式 y = Wx + b 手动计算,需要注意:
- 将 PyTorch 的权重矩阵 W 转置为 W^T(维度:
[input_size, output_size])。
- 将输入向量 x 保持为列向量(维度:
[input_size, 1])。
- PyTorch 实现:
y = x @ W^T + b,其中:
- x 维度:
[batch_size, input_size]
- W 维度:
[output_size, input_size]
- W^T 维度:
[input_size, output_size]
- 数学公式:
y = Wx + b,其中:
- x 维度:
[input_size, 1]
- W 维度:
[output_size, input_size]
这种差异是为了适应批量计算和内存优化,理解后就不需要每次都手动转置了!