整体理解pai0-具身智能-PyTorch einsum 完全教程-11

1. 基础概念

einsum = Einstein Summation (爱因斯坦求和约定)

用简洁的字符串表示复杂的张量运算(乘法、求和、转置等)

2. 基础语法

torch.einsum("equation", tensor1, tensor2, ...)
字母代表维度
相同字母会进行对应相乘
输出中不出现的字母会被求和消除
逗号分隔不同的输入张量
箭头 -> 指定输出维度

Level 1: 向量点积

# 传统方法
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.dot(a, b)  # 1*4 + 2*5 + 3*6 = 32

# einsum 方法
result = torch.einsum('i,i->', a, b)
#                      ↑ ↑  ↑
#                      a b  输出(标量)

i: a 的第 0 维,b 的第 0 维
两个 i 相同 → 对应元素相乘
输出没有 i → 求和

Level 2: 矩阵乘法

# 传统方法
A = torch.randn(3, 4)  # [3, 4]
B = torch.randn(4, 5)  # [4, 5]
C = torch.mm(A, B)     # [3, 5]

# einsum 方法
C = torch.einsum('ik,kj->ij', A, B)
#                 ↑↑  ↑↑  ↑↑
#                 A   B   输出

解析:

A.shape = (3, 4)  # i=3, k=4
B.shape = (4, 5)  # k=4, j=5

# 运算: C[i,j] = Σ_k A[i,k] * B[k,j]
# k 出现在两边但不在输出 → 求和消除
# i, j 在输出 → 保留

C.shape = (3, 5)  # i=3, j=5

Level 3: 批次矩阵乘法(Transformer中常用)

# Batch Matrix Multiplication
A = torch.randn(2, 3, 4)  # [batch, n, k]
B = torch.randn(2, 4, 5)  # [batch, k, m]

# einsum 方法
C = torch.einsum('bik,bkj->bij', A, B)
#                 ↑             ↑
#              batch维度     batch维度
A.shape = (2, 3, 4)  # b=2, i=3, k=4
B.shape = (2, 4, 5)  # b=2, k=4, j=5

# 运算: C[b,i,j] = Σ_k A[b,i,k] * B[b,k,j]
# b 在输出 → 保留(不求和)
# k 不在输出 → 求和消除

C.shape = (2, 3, 5)  # [batch, n, m]

4. PI0 代码中的实际例子

例子1: QKV 投影 (gemma.py:183)

qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))

# 输入
x.shape = (B, S, D)
# B = batch_size (例如 2)
# S = sequence_length (例如 512)
# D = hidden_dim (例如 2048)

# 权重
weight.shape = (3, K, D, H)
# 3 = Q, K, V 三个矩阵
# K = num_kv_heads (例如 1)
# D = hidden_dim (2048)
# H = head_dim (256)

# einsum: "BSD,3KDH->3BSKH"
#          ↑    ↑      ↑
#          x  weight  输出

# 维度对应:
# B: batch (保留)
# S: sequence (保留)
# D: hidden_dim (求和消除,因为不在输出)
# 3: QKV (保留)
# K: num_heads (保留)
# H: head_dim (保留)

# 输出
output.shape = (3, B, S, K, H)
# 例如: (3, 2, 512, 1, 256)

例子2: 注意力计算 (gemma.py:217)

logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k)

解析:

# 输入
q.shape = (B, T, K, G, H)
# B = batch_size
# T = query_length (例如 512)
# K = num_kv_heads (1)
# G = group_size (8, 因为8个query heads / 1个kv head)
# H = head_dim (256)

k.shape = (B, S, K, H)
# S = key_length (例如 512)

# einsum: "BTKGH,BSKH->BKGTS"
#          ↑      ↑     ↑
#          q      k    输出

# 维度对应:
# B: batch (保留)
# T: query_length (保留)
# K: num_kv_heads (保留)
# G: group_size (保留)
# H: head_dim (求和消除!)
# S: key_length (保留)

# 输出
logits.shape = (B, K, G, T, S)

# 语义: logits[b,k,g,t,s] = Σ_h q[b,t,k,g,h] * k[b,s,k,h]
#       即: query位置t 对 key位置s 的注意力分数

T query的长度
S key的长度
G group_size

例子3: 注意力输出 (gemma.py:230)

encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
解析:

# 输入
probs.shape = (B, K, G, T, S)  # 注意力权重(softmax后)
v.shape = (B, S, K, H)          # Value

# einsum: "BKGTS,BSKH->BTKGH"
#          ↑      ↑     ↑
#        probs    v    输出

# 维度对应:
# B: batch (保留)
# K: num_kv_heads (保留)
# G: group_size (保留)
# T: query_length (保留)
# S: key_length (求和消除!) <- 加权求和
# H: head_dim (保留)

# 输出
encoded.shape = (B, T, K, G, H)

# 语义: encoded[b,t,k,g,h] = Σ_s probs[b,k,g,t,s] * v[b,s,k,h]
#       即: 用注意力权重加权 value

维度对应:
B: batch (保留)
K: num_kv_heads (保留)
G: group_size (保留)
T: query_length (保留)
H: head_dim (保留)

5. 常见模式总结

模式1: 矩阵乘法

# 2D
'ik,kj->ij'  # (i,k) @ (k,j) = (i,j)

# 3D (batch)
'bik,bkj->bij'  # (b,i,k) @ (b,k,j) = (b,i,j)

# 4D
'bhik,bhkj->bhij'  # 多头注意力

模式2: 外积

# 向量外积
'i,j->ij'  # (i,) ⊗ (j,) = (i,j)

# 批次外积
'bi,bj->bij'

(i,) ⊗ (j,) = (i, j),表示外积,维度相乘得到二维矩阵。

image
'i,j->ij' 表示将两个一维向量的所有元素两两相乘,生成一个二维矩阵,也就是向量的 外积(outer product)。

模式3: 求和

# 沿某个维度求和
'ijk->ij'   # 对k求和
'ijk->ik'   # 对j求和
'ijk->'     # 全部求和(标量)

模式4: 转置

'ij->ji'    # 转置
'ijk->ikj'  # 交换维度

模式5: 对角线

'ii->i'     # 提取对角线
'bii->bi'   # 批次对角线

6. 调试技巧

技巧1: 写出维度

# 先写出每个张量的维度
A: (3, 4)  # i=3, k=4
B: (4, 5)  # k=4, j=5

# 再写 einsum
'ik,kj->ij'

# 验证: k 求和消除,输出 (i, j) = (3, 5) ✓

技巧2: 分步理解

result = torch.einsum('bik,bkj->bij', A, B)

# 步骤1: 找共同维度
# b: 共同(batch)
# k: 共同(求和)

# 步骤2: 找独有维度
# i: 只在 A
# j: 只在 B

# 步骤3: 确定输出
# b: 保留(在输出中)
# i: 保留(在输出中)
# j: 保留(在输出中)
# k: 消除(不在输出中)

技巧3: 用注释

q = torch.einsum(
    'BTD,NDH->BTNH',  # Query projection
    x,      # [B, T, D] = [batch, seq, hidden]
    w_q,    # [N, D, H] = [heads, hidden, head_dim]
)           # → [B, T, N, H]

7. 练习题

# 1. 简单点积
'i,i->'

# 2. 批次矩阵乘法
'bmn,bnk->bmk'

# 3. 多头注意力
'bhqd,bhkd->bhqk'

# 4. 位置编码
'm,d->md'

# 5. 交叉注意力
'bid,bjd->bij'

希望这个教程能帮你理解 einsum!关键是:
把字母当作维度的名字
相同字母 = 对应相乘
输出中没有的字母 = 求和消除

posted @ 2025-10-24 17:24  jack-chen666  阅读(10)  评论(0)    收藏  举报