用 3D 工具看懂矩阵乘法:张量运算可视化入门

用 3D 工具看懂矩阵乘法:张量运算可视化入门

本文面向 PyTorch 初学者:这篇文章帮你走入 @torch.mmtorch.matmul 的世界。
本文将:

  1. 分清楚 PyTorch 里四种乘法到底有什么区别;
  2. 用一个在线免费的 3D 可视化工具,让你直观体会矩阵乘法计算过程。


前言

深度学习里,张量乘法是预备知识,但很多书并没有系统讲解,默认读者已经掌握,可很多人学到这里都会卡住,因为 PyTorch 里光是"乘法"就有好几种写法,稍不留神就用错。


一、PyTorch 里的"乘法系列"

先上一张全景对比表:

运算名称 NumPy 写法 PyTorch 写法 运算符 一句话说明
逐元素相乘(Hadamard) np.multiply(a, b) torch.mul(a, b) * 形状相同,对应位置各自相乘
矩阵乘法(Matmul) np.matmul(a, b) torch.matmul(a, b) @ 线性代数规则,深度学习最常用
向量点积(Dot Product) np.dot(a, b) torch.dot(a, b) 只能用于两个一维向量,结果是标量
严格矩阵乘法(2D only) torch.mm(a, b) 仅限二维矩阵,比 matmul 更严格

1.1 逐元素相乘 *

最直观的乘法:两个形状完全相同的张量,对应位置各自相乘。

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

result = a * b
# tensor([ 4, 10, 18])
# 计算过程:1×4=4, 2×5=10, 3×6=18

还支持广播(Broadcasting),比如矩阵的每一行都乘以同一个向量:

a = torch.ones(3, 4)           # shape: [3, 4]
b = torch.tensor([1., 2., 3., 4.])  # shape: [4]

c = a * b  # 每一行都乘 b → shape: [3, 4]

记住:* 是"各自乘自己那份",不是线性代数意义上的矩阵乘法。


1.2 向量点积 torch.dot()

点积专门用来计算两个一维向量的内积,结果是一个标量(单个数字)

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])

result = torch.dot(a, b)
# 计算过程:1×4 + 2×5 + 3×6 = 4 + 10 + 18 = 32
# tensor(32)

⚠️ torch.dot() 只接受一维张量,传入矩阵会报错。


1.3 矩阵乘法 @ / torch.matmul()

这是深度学习里用得最多的运算,全连接层的核心就是它:

y = x @ W + b

规则:左矩阵的列数必须等于右矩阵的行数。

A = torch.tensor([[1, 2],
                  [3, 4]])   # shape: [2, 2]
B = torch.tensor([[5, 6],
                  [7, 8]])   # shape: [2, 2]

C = A @ B
# tensor([[19, 22],
#         [43, 50]])

手算一下 C[0, 0]:第一行 [1, 2] 与第一列 [5, 7] 做点积 → 1×5 + 2×7 = 19

torch.matmul() 还支持批量矩阵乘法(Batched Matmul),也就是同时对一批矩阵做乘法:

a = torch.randn(3, 4)       # 单个矩阵 [3, 4]
b = torch.randn(5, 4, 2)    # 一批矩阵 [5, 4, 2]

c = torch.matmul(a, b)
print(c.shape)  # torch.Size([5, 3, 2])
# a 被广播为 (5, 3, 4),与 b 做批量乘法

1.4 严格矩阵乘法 torch.mm()

torch.mm()torch.matmul()二维矩阵上的结果完全相同,区别是:

  • torch.mm() 只接受二维张量,传入三维及以上的张量会报错;
  • torch.matmul() 支持任意维度。
# 仅限 2D
c = torch.mm(A, B)  # 和 A @ B 结果一样

# 如果传入三维张量,torch.mm 会报错:
# RuntimeError: Expected 2D tensors...

什么时候用 torch.mm() 当你想确保代码只处理二维矩阵,用它可以让错误"提前暴露",避免意外的广播行为。


二、用 mm 工具"看到"矩阵乘法

光看公式,矩阵乘法抽象难懂。这里介绍一个神器:

mm — Matmul Visualizations in 3D
🔗 https://bhosmer.github.io/mm

这是 Ben Hosmer 开发的开源工具(MIT 许可),能把矩阵乘法用 3D 动画直观展示出来,免费、免安装,打开网页就能用。

它登上了AI领域核心框架 PyTorch的官方博客 PyTorch官博,收获了 NVIDIA高级科学家Jim Fan 的盛赞,称其为“非常酷的可视化工具”。此外,包括“新智元”等在内的国内外主流科技社区也广泛报道,都证明了它强大的影响力。Pytorch最新工具mm

image


2.1 打开工具,认识界面

访问 https://bhosmer.github.io/mm,默认看到一个 L @ R 的三维可视化:

  • L(左矩阵):矩阵乘法中的第一个矩阵,通常位于 3D 视图的左侧(或下方)。它的行数决定了结果矩阵的行数。
  • R(右矩阵):矩阵乘法中的第二个矩阵,通常位于右侧(或上方)。它的列数决定了结果矩阵的列数。
  • L @ R(结果矩阵)L 与 R 相乘得到的新矩阵,通常显示在正前方(或正后方)。它的每个元素由 L 的对应行与 R 的对应列做点积得到。

这个可视化工具用 3D 立方体的线条交织,生动展示了 L 的每一行如何与 R 的每一列进行运算,最终填充到结果矩阵的对应位置。


2.2 主要参数说明

左侧使用说明:这个区域主要提供参考信息,例如列出的 ZoomSpin 等交互操作,可以理解为快捷键提示,方便你快速上手。

矩阵设置(右侧面板):

除了矩阵乘法 @mm 工具还支持转置 T、加法 +、点乘 * 等运算,并能将它们组合成更复杂的表达式。

基本运算与函数

mm 的表达式 (expr) 栏中,你可以混合使用以下运算和函数:

类别 运算符 / 函数 说明
基本运算 @ 矩阵乘法,mm 的核心运算。
T 转置(Transpose)。例如 R.T 表示矩阵 R 的转置。
+ 矩阵加法。维度必须相同。
* 逐元素相乘(Element-wise Multiplication),也叫哈达玛积(Hadamard product)。
常用函数 softmax() 按行应用 softmax 函数,常用于注意力机制的计算中。
sqrt() 计算矩阵中每个元素的平方根。
辅助符号 小括号 ( ) 用于明确运算的优先级和组合表达式。

基本参数说明

矩阵:

参数 说明
name 矩阵名称,默认 L 和 R
h / w 矩阵的行数(height)和列数(width)
init 初始化方式:row+major(按行填充)或 col+major(按列填充)
min / max 元素值范围,默认 -1 到 1
dropout 随机置零比例,设 0 表示无 dropout

布局与视觉:

参数 可选值 说明
gap 数值 0–20 分块之间的间距,建议设 4,数据流更清楚
scheme blocks / zigzag / wheel 矩阵的分块布局方式,blocks 最清晰
left placement left / right 左矩阵放在哪侧
right placement top / left 右矩阵放在哪侧
result placement front / back 结果矩阵放在前方还是后方

2.3 3D 视图怎么操作

操作 方式
缩放 鼠标滚轮 / 双指捏合
旋转 鼠标按住左键进行拖拽
显示计算详情 长按 + 拖拽

2.4 实践:可视化一次矩阵乘法

拿上面那个 2×2 的例子来操作:

A = [[1, 2],   # 2×2
     [3, 4]]
B = [[5, 6],   # 2×2
     [7, 8]]

# C = A @ B
# C[0,0] = 1*5 + 2*7 = 19
# C[0,1] = 1*6 + 2*8 = 22
# C[1,0] = 3*5 + 4*7 = 43
# C[1,1] = 3*6 + 4*8 = 50

在 mm 中的操作步骤:

  1. 打开 https://bhosmer.github.io/mm
  2. 右侧面板将 L 设为 h=2, w=2,R 设为 h=2, w=2
  3. 找到animation区域,配置动画速度等,按空格键暂停(也可点击 animation 区域的暂停按钮)
  4. 观察:L 的每一行(蓝色高亮)与 R 的每一列(红色高亮)滑入交叉,相乘后汇聚到结果矩阵对应位置

你会直观地看到:C[0,0] 是 L 第 0 行与 R 第 0 列的点积,C[1,1] 是 L 第 1 行与 R 第 1 列的点积……整个乘法过程一目了然。


2.5 计算量也能"看见"

对于 L(m×k) @ R(k×n),结果矩阵是 (m×n)
每个输出元素需要 k 次乘加,总共 m × n × k 次运算。

你可以在 mm 中故意把矩阵设大(比如 32×64 @ 64×32),会看到三维空间里突然出现大量的运算"体积"——这就是深度学习训练慢、需要 GPU 加速的根本原因:每一层的矩阵乘法都在做天量的浮点运算


2.6 彩蛋:可视化 Transformer 注意力机制

mm 还内置了一个 GPT-2 注意力头探索器:

🔗 https://bhosmer.github.io/mm/examples/attngpt2

可以可视化 Transformer 中的 Q × Kᵀ(注意力权重计算)和 O × V(值变换)两步矩阵乘法,直观感受注意力机制底层在做什么。


三、张量存储:storage、stride 和 contiguous

理解了乘法,再补充一个"幕后知识"——张量在内存里到底是怎么存的。这对理解转置、切片、view 等操作的行为很有帮助。

3.1 内存布局:行优先 vs. 列优先

张量(矩阵)在计算机内存中是一维的线性空间,如何把二维的 (行, 列) 映射到一维地址?主要有两种规则:

存储顺序 英文名 排列方式 典型框架
行优先 Row-major 先存完第 0 行的所有列,再存第 1 行,以此类推 C/C++, Python (NumPy, PyTorch 默认), Go
列优先 Column-major 先存完第 0 列的所有行,再存第 1 列,以此类推 Fortran, MATLAB, R, Julia

示例:矩阵 A

A = [[1, 2, 3],
     [4, 5, 6]]
  • 行优先内存顺序:1, 2, 3, 4, 5, 6
  • 列优先内存顺序:1, 4, 2, 5, 3, 6

3.2 为什么这很重要?

转置 (transpose) 是真的交换元素吗?

  • 物理转置:如果真的要交换元素的行列并重新排列内存,代价是 O(n²)。
  • 元数据转置:大部分框架(如 PyTorch、NumPy)的 tensor.Tnp.transpose 只是修改了张量的元数据(步长 stride 和形状 shape),并没有移动内存中的数据。此时内存布局依然是原来的顺序,但通过步长跳跃访问来“假装”已经转置。

切片 (slice) 是视图还是拷贝?

  • 切片操作(如 tensor[1:3, :])通常返回原内存的视图,不复制数据,只是修改了起始偏移量和形状/步长。
  • 如果你修改切片,原张量也会改变。
  • 想要独立副本需显式调用 .clone()

view()reshape() 的区别

  • view():要求张量在内存中连续(contiguous),才能在不复制数据的情况下改变形状。如果不连续,会报错。
  • reshape():自动判断,如果连续就用视图,否则先拷贝成连续再视图。
  • 为什么 view() 要求连续?因为只有连续排列的数据,才能通过简单的数学公式重新映射到新形状。对于一个形状为 (H, W) 的连续张量,元素 (i, j) 的内存偏移量为 i * W + j

当张量在 Storage 里是按顺序紧密排列的,就叫 contiguous(连续)。转置、切片等操作之后,张量通常变得不连续。

⚠️ view()narrow()expand() 等操作要求张量是连续的。如果报错 RuntimeError: non-contiguous tensor,加一句 .contiguous() 就能解决。


四、广播(Broadcasting):形状不同也能算

当两个张量形状不同时,PyTorch 会尝试"广播"来自动对齐。看过了 mm 工具可视化后,你就应该理解了为什么要“广播”:数据不够该怎么运算,是否可以扩充来进行运算,以及什么情况不能进行扩充。

广播规则(从右向左对齐):

  1. 从右向左逐维度比较;
  2. 维度大小相等,或其中一个为 1(可以扩展),否则不可以运算 ;
  3. 缺少的维度视为 1。
a = torch.arange(3).reshape(3, 1)   # shape: [3, 1]
b = torch.arange(2).reshape(1, 2)   # shape: [1, 2]

result = a + b
# tensor([[0, 1],
#         [1, 2],
#         [2, 3]])
# a 沿列方向扩展为 [3, 2]
# b 沿行方向扩展为 [3, 2]
# 然后逐元素相加

常见形状组合一览:

形状 A 形状 B 结果形状 合法?
(3, 4) (4,) (3, 4)
(3, 4) (3, 1) (3, 4)
(3, 4) (3,) ❌ 4 ≠ 3
(2, 3, 4) (3, 4) (2, 3, 4)

五、速查:选哪个乘法?

场景 用这个 理由
形状相同,对应位置乘 *torch.mul() 最直观
两个一维向量求内积 torch.dot() 返回标量,语义明确
严格的二维矩阵乘法 torch.mm() 类型检查严格,不会意外广播
线性代数 / 深度学习 @torch.matmul() 最灵活,支持多维和广播

小结

知识点 关键词
四种乘法 *dotmmmatmul / @
张量存储 Storage → Stride → Contiguous
广播机制 从右对齐,维度为1可扩展
可视化工具 bhosmer.github.io/mm

矩阵乘法是深度学习的"发动机",理解它的运算逻辑和内存机制,是读懂更高层模型(Transformer、CNN)的基础。希望借助 mm 工具,这些原本抽象的概念能在你脑海里留下一个具体的 3D 印象。


可视化工具:mm — Matmul Visualizations in 3D
参考文档:https://bhosmer.github.io/mm/ref.html
GitHub:https://github.com/bhosmer/mm

如果你对张量的切片索引也感到困惑,推荐阅读同系列的 NumPy vs Pandas vs Tensor 切片索引对比图解 ,它系统地梳理了三者切片规则的差异和"视图 vs 副本"的坑。

📢 声明:本文借助AI辅助工具进行资料整理与初稿生成,所有内容均经过作者本人的详细核对、修改与编排,文责自负。

posted @ 2026-04-21 15:48  Lyn_Li  阅读(27)  评论(0)    收藏  举报