为什么nn.Linear 的weight 是 (out_features, in_features)

在PyTorch的nn.Linear中,权重矩阵的形状为(out_features, in_features)。这是因为线性变换的实现方式为:

具体来说:

  1. 当创建nn.Linear(10, 60)时,in_features=10out_features=60,因此权重的形状是(60, 10)
  2. 输入张量t的形状为(2, 5, 10),与转置后的权重a.weight.T(形状(10, 60))相乘时,实际计算为:
    [
    t \in \mathbb{R}^{2 \times 5 \times 10}, \quad a.weight^\top \in \mathbb{R}^{10 \times 60} \implies t \ @ \ a.weight^\top \in \mathbb{R}^{2 \times 5 \times 60}
    ]
    这与直接调用a(t)的结果一致。

因此,a.weight的shape是(60, 10),而非(10, 60),这是PyTorch的设计约定,确保矩阵乘法能正确匹配输入和输出的维度。

posted @ 2025-03-26 00:24  xiezhengcai  阅读(149)  评论(0)    收藏  举报