用 3D 工具看懂矩阵乘法:张量运算可视化入门
用 3D 工具看懂矩阵乘法:张量运算可视化入门
本文面向 PyTorch 初学者:这篇文章帮你走入
@、torch.mm、torch.matmul的世界。
本文将:
- 分清楚 PyTorch 里四种乘法到底有什么区别;
- 用一个在线免费的 3D 可视化工具,让你直观体会矩阵乘法计算过程。
前言
深度学习里,张量乘法是预备知识,但很多书并没有系统讲解,默认读者已经掌握,可很多人学到这里都会卡住,因为 PyTorch 里光是"乘法"就有好几种写法,稍不留神就用错。
一、PyTorch 里的"乘法系列"
先上一张全景对比表:
| 运算名称 | NumPy 写法 | PyTorch 写法 | 运算符 | 一句话说明 |
|---|---|---|---|---|
| 逐元素相乘(Hadamard) | np.multiply(a, b) |
torch.mul(a, b) |
* |
形状相同,对应位置各自相乘 |
| 矩阵乘法(Matmul) | np.matmul(a, b) |
torch.matmul(a, b) |
@ |
线性代数规则,深度学习最常用 |
| 向量点积(Dot Product) | np.dot(a, b) |
torch.dot(a, b) |
— | 只能用于两个一维向量,结果是标量 |
| 严格矩阵乘法(2D only) | — | torch.mm(a, b) |
— | 仅限二维矩阵,比 matmul 更严格 |
1.1 逐元素相乘 *
最直观的乘法:两个形状完全相同的张量,对应位置各自相乘。
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = a * b
# tensor([ 4, 10, 18])
# 计算过程:1×4=4, 2×5=10, 3×6=18
还支持广播(Broadcasting),比如矩阵的每一行都乘以同一个向量:
a = torch.ones(3, 4) # shape: [3, 4]
b = torch.tensor([1., 2., 3., 4.]) # shape: [4]
c = a * b # 每一行都乘 b → shape: [3, 4]
记住: 用 * 是"各自乘自己那份",不是线性代数意义上的矩阵乘法。
1.2 向量点积 torch.dot()
点积专门用来计算两个一维向量的内积,结果是一个标量(单个数字)。
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.dot(a, b)
# 计算过程:1×4 + 2×5 + 3×6 = 4 + 10 + 18 = 32
# tensor(32)
⚠️
torch.dot()只接受一维张量,传入矩阵会报错。
1.3 矩阵乘法 @ / torch.matmul()
这是深度学习里用得最多的运算,全连接层的核心就是它:
y = x @ W + b
规则:左矩阵的列数必须等于右矩阵的行数。
A = torch.tensor([[1, 2],
[3, 4]]) # shape: [2, 2]
B = torch.tensor([[5, 6],
[7, 8]]) # shape: [2, 2]
C = A @ B
# tensor([[19, 22],
# [43, 50]])
手算一下 C[0, 0]:第一行 [1, 2] 与第一列 [5, 7] 做点积 → 1×5 + 2×7 = 19 ✅
torch.matmul() 还支持批量矩阵乘法(Batched Matmul),也就是同时对一批矩阵做乘法:
a = torch.randn(3, 4) # 单个矩阵 [3, 4]
b = torch.randn(5, 4, 2) # 一批矩阵 [5, 4, 2]
c = torch.matmul(a, b)
print(c.shape) # torch.Size([5, 3, 2])
# a 被广播为 (5, 3, 4),与 b 做批量乘法
1.4 严格矩阵乘法 torch.mm()
torch.mm() 和 torch.matmul() 在二维矩阵上的结果完全相同,区别是:
torch.mm()只接受二维张量,传入三维及以上的张量会报错;torch.matmul()支持任意维度。
# 仅限 2D
c = torch.mm(A, B) # 和 A @ B 结果一样
# 如果传入三维张量,torch.mm 会报错:
# RuntimeError: Expected 2D tensors...
什么时候用 torch.mm()? 当你想确保代码只处理二维矩阵,用它可以让错误"提前暴露",避免意外的广播行为。
二、用 mm 工具"看到"矩阵乘法
光看公式,矩阵乘法抽象难懂。这里介绍一个神器:
mm — Matmul Visualizations in 3D
🔗 https://bhosmer.github.io/mm
这是 Ben Hosmer 开发的开源工具(MIT 许可),能把矩阵乘法用 3D 动画直观展示出来,免费、免安装,打开网页就能用。
它登上了AI领域核心框架 PyTorch的官方博客 PyTorch官博,收获了 NVIDIA高级科学家Jim Fan 的盛赞,称其为“非常酷的可视化工具”。此外,包括“新智元”等在内的国内外主流科技社区也广泛报道,都证明了它强大的影响力。Pytorch最新工具mm

2.1 打开工具,认识界面
访问 https://bhosmer.github.io/mm,默认看到一个 L @ R 的三维可视化:
- L(左矩阵):矩阵乘法中的第一个矩阵,通常位于 3D 视图的左侧(或下方)。它的行数决定了结果矩阵的行数。
- R(右矩阵):矩阵乘法中的第二个矩阵,通常位于右侧(或上方)。它的列数决定了结果矩阵的列数。
- L @ R(结果矩阵):
L与R相乘得到的新矩阵,通常显示在正前方(或正后方)。它的每个元素由L的对应行与R的对应列做点积得到。
这个可视化工具用 3D 立方体的线条交织,生动展示了 L 的每一行如何与 R 的每一列进行运算,最终填充到结果矩阵的对应位置。
2.2 主要参数说明
左侧使用说明:这个区域主要提供参考信息,例如列出的 Zoom、Spin 等交互操作,可以理解为快捷键提示,方便你快速上手。
矩阵设置(右侧面板):
除了矩阵乘法 @,mm 工具还支持转置 T、加法 +、点乘 * 等运算,并能将它们组合成更复杂的表达式。
基本运算与函数
在 mm 的表达式 (expr) 栏中,你可以混合使用以下运算和函数:
| 类别 | 运算符 / 函数 | 说明 |
|---|---|---|
| 基本运算 | @ |
矩阵乘法,mm 的核心运算。 |
T |
转置(Transpose)。例如 R.T 表示矩阵 R 的转置。 |
|
+ |
矩阵加法。维度必须相同。 | |
* |
逐元素相乘(Element-wise Multiplication),也叫哈达玛积(Hadamard product)。 | |
| 常用函数 | softmax() |
按行应用 softmax 函数,常用于注意力机制的计算中。 |
sqrt() |
计算矩阵中每个元素的平方根。 | |
| 辅助符号 | 小括号 ( ) |
用于明确运算的优先级和组合表达式。 |
基本参数说明
矩阵:
| 参数 | 说明 |
|---|---|
name |
矩阵名称,默认 L 和 R |
h / w |
矩阵的行数(height)和列数(width) |
init |
初始化方式:row+major(按行填充)或 col+major(按列填充) |
min / max |
元素值范围,默认 -1 到 1 |
dropout |
随机置零比例,设 0 表示无 dropout |
布局与视觉:
| 参数 | 可选值 | 说明 |
|---|---|---|
gap |
数值 0–20 | 分块之间的间距,建议设 4,数据流更清楚 |
scheme |
blocks / zigzag / wheel |
矩阵的分块布局方式,blocks 最清晰 |
left placement |
left / right |
左矩阵放在哪侧 |
right placement |
top / left |
右矩阵放在哪侧 |
result placement |
front / back |
结果矩阵放在前方还是后方 |
2.3 3D 视图怎么操作
| 操作 | 方式 |
|---|---|
| 缩放 | 鼠标滚轮 / 双指捏合 |
| 旋转 | 鼠标按住左键进行拖拽 |
| 显示计算详情 | 长按 + 拖拽 |
2.4 实践:可视化一次矩阵乘法
拿上面那个 2×2 的例子来操作:
A = [[1, 2], # 2×2
[3, 4]]
B = [[5, 6], # 2×2
[7, 8]]
# C = A @ B
# C[0,0] = 1*5 + 2*7 = 19
# C[0,1] = 1*6 + 2*8 = 22
# C[1,0] = 3*5 + 4*7 = 43
# C[1,1] = 3*6 + 4*8 = 50
在 mm 中的操作步骤:
- 打开 https://bhosmer.github.io/mm
- 右侧面板将 L 设为
h=2, w=2,R 设为h=2, w=2 - 找到
animation区域,配置动画速度等,按空格键暂停(也可点击animation区域的暂停按钮) - 观察:L 的每一行(蓝色高亮)与 R 的每一列(红色高亮)滑入交叉,相乘后汇聚到结果矩阵对应位置
你会直观地看到:C[0,0] 是 L 第 0 行与 R 第 0 列的点积,C[1,1] 是 L 第 1 行与 R 第 1 列的点积……整个乘法过程一目了然。
2.5 计算量也能"看见"
对于 L(m×k) @ R(k×n),结果矩阵是 (m×n)。
每个输出元素需要 k 次乘加,总共 m × n × k 次运算。
你可以在 mm 中故意把矩阵设大(比如 32×64 @ 64×32),会看到三维空间里突然出现大量的运算"体积"——这就是深度学习训练慢、需要 GPU 加速的根本原因:每一层的矩阵乘法都在做天量的浮点运算。
2.6 彩蛋:可视化 Transformer 注意力机制
mm 还内置了一个 GPT-2 注意力头探索器:
🔗 https://bhosmer.github.io/mm/examples/attngpt2
可以可视化 Transformer 中的 Q × Kᵀ(注意力权重计算)和 O × V(值变换)两步矩阵乘法,直观感受注意力机制底层在做什么。
三、张量存储:storage、stride 和 contiguous
理解了乘法,再补充一个"幕后知识"——张量在内存里到底是怎么存的。这对理解转置、切片、view 等操作的行为很有帮助。
3.1 内存布局:行优先 vs. 列优先
张量(矩阵)在计算机内存中是一维的线性空间,如何把二维的 (行, 列) 映射到一维地址?主要有两种规则:
| 存储顺序 | 英文名 | 排列方式 | 典型框架 |
|---|---|---|---|
| 行优先 | Row-major | 先存完第 0 行的所有列,再存第 1 行,以此类推 | C/C++, Python (NumPy, PyTorch 默认), Go |
| 列优先 | Column-major | 先存完第 0 列的所有行,再存第 1 列,以此类推 | Fortran, MATLAB, R, Julia |
示例:矩阵 A
A = [[1, 2, 3],
[4, 5, 6]]
- 行优先内存顺序:
1, 2, 3, 4, 5, 6 - 列优先内存顺序:
1, 4, 2, 5, 3, 6
3.2 为什么这很重要?
转置 (transpose) 是真的交换元素吗?
- 物理转置:如果真的要交换元素的行列并重新排列内存,代价是 O(n²)。
- 元数据转置:大部分框架(如 PyTorch、NumPy)的
tensor.T或np.transpose只是修改了张量的元数据(步长stride和形状shape),并没有移动内存中的数据。此时内存布局依然是原来的顺序,但通过步长跳跃访问来“假装”已经转置。
切片 (slice) 是视图还是拷贝?
- 切片操作(如
tensor[1:3, :])通常返回原内存的视图,不复制数据,只是修改了起始偏移量和形状/步长。 - 如果你修改切片,原张量也会改变。
- 想要独立副本需显式调用
.clone()。
view() 与 reshape() 的区别
view():要求张量在内存中连续(contiguous),才能在不复制数据的情况下改变形状。如果不连续,会报错。reshape():自动判断,如果连续就用视图,否则先拷贝成连续再视图。- 为什么
view()要求连续?因为只有连续排列的数据,才能通过简单的数学公式重新映射到新形状。对于一个形状为(H, W)的连续张量,元素(i, j)的内存偏移量为i * W + j。
当张量在 Storage 里是按顺序紧密排列的,就叫 contiguous(连续)。转置、切片等操作之后,张量通常变得不连续。
⚠️
view()、narrow()、expand()等操作要求张量是连续的。如果报错RuntimeError: non-contiguous tensor,加一句.contiguous()就能解决。
四、广播(Broadcasting):形状不同也能算
当两个张量形状不同时,PyTorch 会尝试"广播"来自动对齐。看过了 mm 工具可视化后,你就应该理解了为什么要“广播”:数据不够该怎么运算,是否可以扩充来进行运算,以及什么情况不能进行扩充。
广播规则(从右向左对齐):
- 从右向左逐维度比较;
- 维度大小相等,或其中一个为 1(可以扩展),否则不可以运算 ;
- 缺少的维度视为 1。
a = torch.arange(3).reshape(3, 1) # shape: [3, 1]
b = torch.arange(2).reshape(1, 2) # shape: [1, 2]
result = a + b
# tensor([[0, 1],
# [1, 2],
# [2, 3]])
# a 沿列方向扩展为 [3, 2]
# b 沿行方向扩展为 [3, 2]
# 然后逐元素相加
常见形状组合一览:
| 形状 A | 形状 B | 结果形状 | 合法? |
|---|---|---|---|
| (3, 4) | (4,) | (3, 4) | ✅ |
| (3, 4) | (3, 1) | (3, 4) | ✅ |
| (3, 4) | (3,) | — | ❌ 4 ≠ 3 |
| (2, 3, 4) | (3, 4) | (2, 3, 4) | ✅ |
五、速查:选哪个乘法?
| 场景 | 用这个 | 理由 |
|---|---|---|
| 形状相同,对应位置乘 | * 或 torch.mul() |
最直观 |
| 两个一维向量求内积 | torch.dot() |
返回标量,语义明确 |
| 严格的二维矩阵乘法 | torch.mm() |
类型检查严格,不会意外广播 |
| 线性代数 / 深度学习 | @ 或 torch.matmul() |
最灵活,支持多维和广播 |
小结
| 知识点 | 关键词 |
|---|---|
| 四种乘法 | *、dot、mm、matmul / @ |
| 张量存储 | Storage → Stride → Contiguous |
| 广播机制 | 从右对齐,维度为1可扩展 |
| 可视化工具 | bhosmer.github.io/mm |
矩阵乘法是深度学习的"发动机",理解它的运算逻辑和内存机制,是读懂更高层模型(Transformer、CNN)的基础。希望借助 mm 工具,这些原本抽象的概念能在你脑海里留下一个具体的 3D 印象。
可视化工具:mm — Matmul Visualizations in 3D
参考文档:https://bhosmer.github.io/mm/ref.html
GitHub:https://github.com/bhosmer/mm
如果你对张量的切片索引也感到困惑,推荐阅读同系列的 NumPy vs Pandas vs Tensor 切片索引对比图解 ,它系统地梳理了三者切片规则的差异和"视图 vs 副本"的坑。
📢 声明:本文借助AI辅助工具进行资料整理与初稿生成,所有内容均经过作者本人的详细核对、修改与编排,文责自负。

浙公网安备 33010602011771号