行行重行行:用 Prefix Scan 做方阵乘法

在离散时间的量子演化里,我们经常面对一串小规模 dense 方阵。它们可能来自 Trotter 分解、分段常哈密顿量、控制脉冲序列,或者某个时间网格上的局部传播子。把这些矩阵记作 \(A_0, A_1, \dots, A_{n-1}\),最直接的问题是计算最终连乘 \(A_0 A_1 \cdots A_{n-1}\)。但很多时候我们不只要最后一步,还要每个时间点的中间演化:\(P_i = A_0 A_1 \cdots A_i\)

这就是矩阵乘法版本的 prefix scan。和标量前缀和相比,它没有换一个名字就变得神秘;区别只在于运算从加法变成了矩阵乘法。矩阵乘法满足结合律,所以可以被 scan;矩阵乘法一般不满足交换律,所以顺序不能乱。

上面仅仅是一个物理背景,即使不了解也没有关系。我们现在统一一下具体问题:已知同样大小的方阵 \(A_0,A_1 \cdots A_{n-1}\),在并行计算资源充足的情况下,要快速求解所有 \(P_i = A_0 A_1 \cdots A_i\)

从顺序前缀积开始

如果 \(A_i\) 都是 \(m \times m\) dense 方阵,顺序代码很直接:先令 \(P_0 = A_0\),然后依次计算 \(P_i = P_{i-1} A_i\)。这里每一步都是一次小矩阵乘法,单次代价约为 \(O(m^3)\),总工作量是 \(O(n m^3)\),关键路径长度也是 \(O(n)\) 次矩阵乘法。

这个算法的问题不是总工作量太大,而是时间方向上的依赖太硬。\(P_7\) 必须等 \(P_6\)\(P_6\) 必须等 \(P_5\),整个计算像一条绷直的绳子。在 CPU 多核或 GPU 上,如果 \(m\) 很小,每个矩阵乘法本身可能不足以填满设备;真正的并行机会藏在矩阵序列之间,而顺序递推没有把它暴露出来。

树形规约:先把区间乘积算出来

把数组看成一棵二叉树。叶子是原始矩阵,内部节点表示一个连续区间的乘积。例如长度为 8 时,底层先计算相邻二元组:\(B_0 = A_0 A_1\)\(B_1 = A_2 A_3\)\(B_2 = A_4 A_5\)\(B_3 = A_6 A_7\)

再往上计算 \(C_0 = B_0 B_1 = A_0 A_1 A_2 A_3\)\(C_1 = B_2 B_3 = A_4 A_5 A_6 A_7\)。最终根节点就是 \(C_0 C_1\)

这一步叫 upsweep 或 reduce。它只回答“每个区间的总乘积是什么”,还没有给出所有前缀。树上每个内部节点保存的是 \(T(l,r) = A_l A_{l+1} \cdots A_r\)。合并两个相邻区间 \([l, mid]\)\([mid+1, r]\) 时,只能写成 \(T(l,r) = T(l,mid) T(mid+1,r)\),左右顺序不能反。

这点在量子演化里尤其容易踩坑。若你的物理约定是 \(|\psi_{k+1}\rangle = U_k |\psi_k\rangle\),那么 \(|\psi_n\rangle = U_{n-1} \cdots U_1 U_0 |\psi_0\rangle\),时间后发生的算符在左边。本文为了讲 scan,把前缀写成 \(A_0 A_1 \cdots A_i\)。实际实现时只要统一存储顺序和乘法方向即可,可以把 \(A_i\) 定义成已经按代码中的乘法顺序排列的块,也可以改做反向 scan。算法依赖结合律,不依赖你采用哪套物理记号。

image

反向传播:把左侧上下文送到每个节点

规约树算出了区间乘积,接下来还要算出每一个前缀乘积。、prefix scan 还差一个信息:每个区间左边已经累积了什么。设某个节点表示区间 \([l,r]\),从根向下传给它一个上下文矩阵 \(S\),含义是“这个区间之前所有矩阵的乘积”。对于典型的 inclusive scan 形式,叶子 \(A_i\) 拿到上下文 \(S_i = A_0 \cdots A_{i-1}\) 后,输出就是 \(S_i A_i\)。对于第一个叶子,\(S_0\) 是单位矩阵 \(I\)

根节点的上下文是 \(I\)。如果一个节点被分成左孩子 \(L=[l,mid]\) 和右孩子 \(R=[mid+1,r]\),并且当前上下文为 \(S\),那么传给左孩子的上下文仍然是 \(S\),传给右孩子的上下文则是 \(S\cdot T(L)\),其中 \(T(L)\) 是左孩子的区间乘积。也就是说,右子树里的每个元素都需要先乘上整段左子树的贡献。

这一阶段常叫 downsweep。它和反向传播这个词有一点形似:信息从根向叶子传播,每个分支携带一个外部上下文。不过这里没有梯度,只有前缀上下文。把它写成递归形式会很清楚:

scan(node, S):
    if node is leaf A_i:
        P_i = S * A_i
    else:
        L, R = node.left, node.right
        scan(L, S)
        scan(R, S * T(L))

这个伪代码已经包含了实现时最重要的边界条件。单位元必须是 \(m \times m\) 的单位矩阵;右孩子上下文必须乘上左孩子的区间乘积;如果矩阵乘法方向采用另一套约定,对应的上下文公式也要整体翻转,不能只改某一行。

复杂度和并行性

顺序算法的矩阵乘法次数是 \(n-1\),工作量为 \(O(nm^3)\),关键路径为 \(O(nm^3)\)。树形 scan 的 upsweep 有 \(n-1\) 个内部节点,因此需要 \(n-1\) 次区间合并;downsweep 中每个内部节点通常还要为右孩子计算一次新的上下文,因此又接近 \(n-1\) 次矩阵乘法。最后每个叶子的输出若按 \(S_i A_i\) 显式计算,还会有 \(n\) 次乘法;有些实现会把叶子输出融入 downsweep,减少一部分常数,但数量级仍是 \(O(nm^3)\)

因此 scan 的工作量不比顺序递推更省。它的价值在 depth:理想二叉树上,upsweep 需要 \(O(\log n)\) 层,downsweep 也需要 \(O(\log n)\) 层。若每层有足够多独立节点可并行执行,关键路径从 \(O(n)\) 次矩阵乘法下降到 \(O(\log n)\) 次矩阵乘法。用 Brent 定理的语言说,在 \(p\) 个处理单元上,运行时间大致受 \(O(nm^3/p + m^3\log n)\) 控制,忽略调度、访存和同步常数。

真实机器上,常数并不小。每一层之间有同步,区间乘积需要写入和读取,小矩阵乘法本身还可能受寄存器、cache、SIMD 宽度或 GPU occupancy 影响。当 \(n\) 不够大、\(m\) 稍大到单个 GEMM 已经很饱满,或者只需要最后一个乘积而不需要所有中间前缀时,scan 未必划算。相反,当 \(m\) 小、\(n\) 长、每个时间点的中间演化都要保留,或者后续还要对每个 \(P_i\) 批量作用到态矢量或观测量上,前缀 scan 更容易把时间维度上的并行性释放出来。

实现上的几个细节

小 dense 方阵最好连续存放,例如布局为 A[n][m][m] 或扁平化后的 A[n * m * m]。对 CPU 来说,\(m\) 很小时手写固定尺寸 kernel 往往比调用通用 GEMM 更合适;对 GPU 来说,可以让一个 warp 或一个 thread block 处理一个矩阵乘法,具体取决于 \(m\)、批量大小和寄存器压力。scan 的树层级通常按数组下标组织,避免真实构建指针树。

\(n\) 不是二的幂时也没有本质困难。常见做法是把缺失叶子补成单位矩阵,或者在每层合并时判断右邻居是否存在。补单位矩阵逻辑更统一,但会做一些无用乘法;分支判断节省工作,却可能让 GPU kernel 更复杂。对短序列,简单性往往比省掉几个乘法更重要。

数值误差也要留意。矩阵乘法满足数学上的结合律,但浮点数乘法不严格结合。树形加括号和顺序递推的舍入路径不同,结果可能有微小差异。对于幺正演化,连续相乘还可能慢慢偏离幺正性;必要时可以监控 \(P_i^\dagger P_i - I\),或在物理模型允许时做重正交化、投影回幺正群、使用更稳定的参数化表示。并行化不应该悄悄改变你对误差的容忍标准。

posted @ 2026-05-27 17:07  Ofnoname  阅读(7)  评论(0)    收藏  举报