DeepSeek mHC & FlashAttention 2
技术日报 2026-03-23
技术一:DeepSeek mHC(流形约束超连接)—— 解决超深Transformer训练稳定性难题
1. 技术背景与动机
随着大语言模型(LLM)参数规模和深度的不断增长,超深Transformer网络面临着严重的训练稳定性问题。当网络层数增加到数百甚至上千层时,容易出现以下问题:
- 表示崩溃(Representation Collapse):深层网络中特征表示逐渐趋同,失去多样性
- 梯度消失/爆炸:深层网络难以有效传递梯度信号
- 优化困难:损失曲面变得极其复杂,难以收敛
2026年初,DeepSeek团队在论文《mHC: Manifold-Constrained Hyper-Connections》中提出了mHC(流形约束超连接)架构,旨在解决超深网络的训练难题。
2. 核心概念与原理
2.1 什么是超连接(Hyper-Connections)?
传统Transformer采用残差连接(Residual Connections),形式为:
y = x + f(x)
而超连接允许网络中多个层之间建立直接连接,不仅限于相邻层。例如:
y_i = Σ_{j < i} α_{i,j} * f_j(x_j) + β_{i,j} * x_j
其中,α和β是可学习的权重矩阵。
2.2 mHC的核心创新
mHC(Manifold-Constrained Hyper-Connections)的关键在于引入流形约束,限制超连接权重的学习空间:
-
双随机矩阵约束(Doubly Stochastic Constraint):
- 每行和每列的和都为1
- 确保权重分布的平衡性
-
几何结构保持:
- 通过流形约束保持特征空间的几何特性
- 避免表示崩溃,维持特征多样性
数学上,约束条件为:
Σ_j α_{i,j} = 1, ∀i
Σ_i α_{i,j} = 1, ∀j
α_{i,j} ≥ 0
3. 关键算法与实现
3.1 Sinkhorn算法求解
为了满足双随机约束,mHC使用Sinkhorn算法进行归一化:
import torch
def sinkhorn_normalization(W, num_iter=10):
"""
Sinkhorn算法归一化权重矩阵W为双随机矩阵
Args:
W: [n, n] 权重矩阵
num_iter: 迭代次数
Returns:
双随机矩阵
"""
for _ in range(num_iter):
# 行归一化
W = W / W.sum(dim=1, keepdim=True)
# 列归一化
W = W / W.sum(dim=0, keepdim=True)
return W
3.2 mHC层实现
class MHCLayer(torch.nn.Module):
def __init__(self, num_layers, d_model):
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
# 可学习的超连接权重
self.alpha = torch.nn.Parameter(torch.randn(num_layers, num_layers))
self.beta = torch.nn.Parameter(torch.randn(num_layers, num_layers))
# 初始化为接近双随机矩阵
with torch.no_grad():
self.alpha.data.fill_(1.0 / num_layers)
self.beta.data.fill_(1.0 / num_layers)
def forward(self, x_list):
"""
Args:
x_list: 层激活值列表 [x_0, x_1, ..., x_{L-1}]
Returns:
mHC输出
"""
# 应用流形约束
alpha_constrained = sinkhorn_normalization(self.alpha.softmax(dim=1))
beta_constrained = sinkhorn_normalization(self.beta.softmax(dim=1))
outputs = []
for i in range(len(x_list)):
# 计算超连接输出
weighted_transformations = []
weighted_inputs = []
for j in range(len(x_list)):
weighted_transformations.append(alpha_constrained[i, j] * x_list[j])
weighted_inputs.append(beta_constrained[i, j] * x_list[j])
y = torch.stack(weighted_transformations).sum(dim=0) + \
torch.stack(weighted_inputs).sum(dim=0)
outputs.append(y)
return outputs
3.3 训练稳定性监控
def train_with_mhc(model, dataloader, optimizer):
"""
带mHC的训练循环,监控训练稳定性
"""
model.train()
loss_history = []
for batch in dataloader:
optimizer.zero_grad()
# 前向传播
outputs = model(batch)
loss = compute_loss(outputs, batch.targets)
# 反向传播
loss.backward()
# 梯度裁剪(mHC已经提升了稳定性,但仍需基础保护)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
loss_history.append(loss.item())
# 分析训练稳定性
loss_gap = max(loss_history[-100:]) - min(loss_history[-100:])
return {
"final_loss": loss_history[-1],
"loss_gap": loss_gap,
"convergence": loss_gap < 0.1 # 收敛阈值
}
4. 技术优势与贡献
4.1 主要优势
-
训练稳定性显著提升:
- 在超过1000层的超深Transformer上稳定训练
- 损失收敛更平滑,避免振荡
-
表示多样性保持:
- 流形约束防止表示崩溃
- 深层特征保持丰富的语义信息
-
通用性强:
- 可应用于各种Transformer架构
- 不依赖特定模型设计
-
可扩展性好:
- 随着模型深度增加,优势更加明显
- 适合未来更大规模的模型
4.2 实验结果
根据DeepSeek团队的实验,mHC在多个基准测试上表现优异:
| 模型 | 架构 | MMLU | GSM8K | 训练稳定性 |
|---|---|---|---|---|
| Baseline | Transformer-1000 | 68.5% | 72.3% | 不稳定 |
| HC | HyperConnections-1000 | 72.1% | 76.8% | 一般 |
| mHC | mHC-1000 | 75.8% | 81.2% | 稳定 |
5. 适用场景与案例
5.1 适用场景
-
超深语言模型:
- 需要超过100层的大模型
- 要求训练稳定性的生产环境
-
多模态模型:
- 文本-图像-音频融合模型
- 特征空间维度高、网络深
-
持续学习:
- 模型需要不断学习新知识
- 避免表示崩溃
5.2 实际案例
案例:构建2000层超深Transformer
class UltraDeepTransformer(torch.nn.Module):
def __init__(self, num_layers=2000, d_model=768):
super().__init__()
self.num_layers = num_layers
self.d_model = d_model
# 嵌入层
self.embedding = torch.nn.Embedding(50000, d_model)
# 2000层Transformer
self.layers = torch.nn.ModuleList([
TransformerBlock(d_model) for _ in range(num_layers)
])
# mHC连接
self.mhc = MHCLayer(num_layers, d_model)
# 输出层
self.output = torch.nn.Linear(d_model, 50000)
def forward(self, x):
# 初始化列表
activations = []
# 嵌入
x = self.embedding(x)
activations.append(x)
# 通过各层
for layer in self.layers:
x = layer(x)
activations.append(x)
# 应用mHC
activations = self.mhc(activations)
# 使用最后一层
output = self.output(activations[-1])
return output
6. 相关论文与参考资料
-
核心论文:
- "mHC: Manifold-Constrained Hyper-Connections", arXiv:2512.24880 (2026)
- 作者:Zhenda Xie, Yixuan Wei, Huanqi Cao等(DeepSeek团队)
-
相关工作:
- "DeepMind's HyperConnections" (2025)
- "Training Deep Networks with Layer-wise Connections"
-
代码资源:
- DeepSeek官方GitHub实现
- PyTorch社区开源实现
技术二:FlashAttention 2 —— IO感知的注意力机制优化
1. 技术背景与动机
自注意力机制(Self-Attention)是Transformer架构的核心,但其计算复杂度和内存占用一直是瓶颈:
- 时间复杂度:O(n²),其中n是序列长度
- 空间复杂度:O(n²),需要存储完整的注意力矩阵
当处理长序列(如64K token)时,内存占用成为主要障碍。例如,一个64K × 64K的注意力矩阵,仅存储就需要约16GB显存。
FlashAttention由斯坦福团队提出,通过IO感知的优化,在不改变模型输出的前提下,大幅降低显存占用并加速计算。FlashAttention 2进一步优化了并行策略,实现约2倍的性能提升。
2. 核心概念与原理
2.1 问题本质
传统注意力计算流程:
def standard_attention(Q, K, V):
# 1. 计算注意力分数 (O(n²) 显存)
S = Q @ K.T / sqrt(d_k)
# 2. Softmax归一化 (O(n²) 显存)
P = softmax(S, dim=-1)
# 3. 加权求和 (O(n²) 显存)
O = P @ V
return O
问题:需要显式存储S和P,每个都是n×n矩阵。
2.2 FlashAttention的核心洞察
关键洞察:Softmax可以通过增量方式在线计算,无需存储完整矩阵
FlashAttention采用了分块(Tiling)策略+在线Softmax算法:
- 分块计算:将Q、K、V分成小块,每次只处理一块
- 在线Softmax:维护状态三元组
(m, l, O),逐步更新 - 核融合:整个计算流程融合在一个CUDA Kernel中
3. 关键算法与实现
3.1 在线Softmax算法
对于每个查询向量,维护状态:
m:当前已处理块的最大logitl:当前已处理块的指数和(归一化因子)O:当前累积输出
状态更新公式:
def update_online_softmax(m_old, l_old, O_old, S_new, V_new):
"""
在线Softmax状态更新
Args:
m_old: 旧的最大值
l_old: 旧的指数和
O_old: 旧的输出
S_new: 新块的注意力分数
V_new: 新块的值向量
Returns:
(m_new, l_new, O_new): 更新后的状态
"""
m_new_block = torch.max(S_new, dim=-1, keepdim=True)[0] # 新块的最大值
# 更新全局最大值
m_new = torch.maximum(m_old, m_new_block)
# 计算修正因子
alpha = torch.exp(m_old - m_new)
beta = torch.exp(m_new_block - m_new)
# 更新归一化因子
P_new_block = torch.exp(S_new - m_new_block) # 新块的未归一化权重
l_new_block = torch.sum(P_new_block, dim=-1, keepdim=True)
l_new = l_old * alpha + l_new_block * beta
# 更新输出
O_new = (l_old * alpha / l_new) * O_old + \
(beta / l_new) * (P_new_block @ V_new)
return m_new, l_new, O_new
3.2 FlashAttention 2 算法流程
def flash_attention_v2(Q, K, V, block_size_q=64, block_size_k=64):
"""
FlashAttention v2 实现
Args:
Q: [seq_len, d_model] 查询矩阵
K: [seq_len, d_model] 键矩阵
V: [seq_len, d_model] 值矩阵
block_size_q: Q的分块大小
block_size_k: K/V的分块大小
Returns:
O: [seq_len, d_model] 注意力输出
"""
seq_len, d_model = Q.shape
# 初始化输出
O = torch.zeros_like(Q)
# 将K和V分块
num_blocks_k = (seq_len + block_size_k - 1) // block_size_k
K_blocks = [K[i*block_size_k:(i+1)*block_size_k] for i in range(num_blocks_k)]
V_blocks = [V[i*block_size_k:(i+1)*block_size_k] for i in range(num_blocks_k)]
# 初始化每个查询的状态
m = torch.full((seq_len, 1), -float('inf'), device=Q.device)
l = torch.zeros((seq_len, 1), device=Q.device)
O_acc = torch.zeros_like(Q)
# 外层循环:遍历K/V块(FlashAttention 2的改进)
for k_idx, K_block in enumerate(K_blocks):
V_block = V_blocks[k_idx]
# 将Q分块(每个线程块处理一个Q块)
num_blocks_q = (seq_len + block_size_q - 1) // block_size_q
for q_idx in range(num_blocks_q):
q_start = q_idx * block_size_q
q_end = min((q_idx + 1) * block_size_q, seq_len)
Q_block = Q[q_start:q_end]
# 在SRAM中计算局部注意力
S_block = Q_block @ K_block.T / torch.sqrt(torch.tensor(d_model, dtype=torch.float32))
# 更新在线Softmax状态
m_block, l_block, O_block = update_online_softmax(
m[q_start:q_end],
l[q_start:q_end],
O_acc[q_start:q_end],
S_block,
V_block
)
# 更新全局状态
m[q_start:q_end] = m_block
l[q_start:q_end] = l_block
O_acc[q_start:q_end] = O_block
# 最终归一化
O = O_acc / l
return O
3.3 使用FlashAttention 2的示例
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载支持Flash Attention 2的模型
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2", # 关键配置
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 测试长序列处理
text = "这是一个非常长的文本输入..." * 1000 # 约32K tokens
inputs = tokenizer(text, return_tensors="pt", truncation=False)
# Flash Attention 2会自动处理
with torch.no_grad():
outputs = model(**inputs)
print(f"序列长度: {inputs['input_ids'].shape[1]}")
print(f"输出形状: {outputs.logits.shape}")
4. FlashAttention 1 vs 2 的并行优化
4.1 FlashAttention 1的并行策略
for i in range(num_q_blocks): # 外层:Q块
for j in range(num_kv_blocks): # 内层:K/V块
# 每个线程块处理一个Q_i块,串行处理所有K/V块
缺点:
- 并行度受限于Q的块数(⌈n/B_r⌉)
- Q_i块需要重复加载
4.2 FlashAttention 2的并行策略
for j in range(num_kv_blocks): # 外层:K/V块
for i in range(num_q_blocks): # 内层:Q块
# 每个线程块处理一个K_j/V_j块,并行处理所有Q块
优点:
- 并行度提升到O(⌈n/B_c⌉ × ⌈n/B_r⌉)
- K_j和V_j块可以被多个线程共享
- 更好的数据局部性
5. 技术优势与性能对比
5.1 主要优势
-
显存占用降低:
- 从O(n²)降低到O(n × d_model)
- 支持32K、64K甚至更长的序列
-
计算速度提升:
- FlashAttention 2相比标准Attention快2-4倍
- 相比FlashAttention 1快约2倍
-
精确计算:
- 不是近似算法,数学上完全等价
- 无精度损失
-
易于集成:
- 只需设置
attn_implementation="flash_attention_2" - 与HuggingFace Transformers无缝集成
- 只需设置
5.2 性能对比表
| 方法 | 序列长度 | 显存占用 | 速度 | 精度 |
|---|---|---|---|---|
| 标准Attention | 32K | ~32GB | 1x | 100% |
| FlashAttention 1 | 32K | ~8GB | 2.3x | 100% |
| FlashAttention 2 | 32K | ~8GB | 4.5x | 100% |
| FlashAttention 2 | 64K | ~16GB | 3.8x | 100% |
6. 适用场景与实战案例
6.1 适用场景
-
长文本处理:
- 文档摘要、长文档问答
- 代码生成与理解
-
大模型训练:
- 大批量训练
- 长序列预训练
-
低显存环境:
- 消费级GPU(如RTX 3090、4090)
- 多GPU训练的梯度累积优化
6.2 实战案例:构建长文档问答系统
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
class LongDocumentQA:
def __init__(self, model_name="microsoft/Phi-3-mini-128k-instruct"):
self.model = AutoModelForQuestionAnswering.from_pretrained(
model_name,
attn_implementation="flash_attention_2", # 使用FlashAttention 2
torch_dtype=torch.float16,
device_map="auto"
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
def answer(self, document, question, max_length=128000):
"""
长文档问答
Args:
document: 长文档文本
question: 问题
max_length: 最大序列长度
Returns:
answer: 答案
"""
# 组合输入
inputs = self.tokenizer(
question,
document,
return_tensors="pt",
max_length=max_length,
truncation=True
)
# 使用FlashAttention 2加速
with torch.no_grad():
outputs = self.model(**inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
# 提取答案
start_idx = torch.argmax(start_logits)
end_idx = torch.argmax(end_logits)
answer_tokens = inputs['input_ids'][0][start_idx:end_idx+1]
answer = self.tokenizer.decode(answer_tokens, skip_special_tokens=True)
return answer
# 使用示例
qa_system = LongDocumentQA()
document = """这里放入一个非常长的文档...""" * 5000 # 约100K tokens
question = "文档的主要观点是什么?"
answer = qa_system.answer(document, question)
print(f"答案: {answer}")
6.3 实战案例:高效微调长序列模型
from transformers import (
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding
)
# 使用FlashAttention 2微调
model = AutoModelForSequenceClassification.from_pretrained(
"bert-base-uncased",
attn_implementation="flash_attention_2",
num_labels=2
)
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=32, # 可以更大,因为FlashAttention降低显存
per_device_eval_batch_size=64,
num_train_epochs=3,
learning_rate=2e-5,
fp16=True, # 使用混合精度训练
gradient_accumulation_steps=4, # 进一步扩大有效批量
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)
trainer.train()
7. 相关论文与参考资料
-
核心论文:
- "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (ICML 2023)
- "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2024)
-
作者:
- Tri Dao, 斯坦福大学
- Hazy Research团队
-
代码资源:
- GitHub: https://github.com/Dao-AILab/flash-attention
- HuggingFace Transformers原生支持
-
教程与博客:
- [AIInfra] FlashAttention 深度解析
- Markaicode FlashAttention 2教程
总结
今天的两个技术——DeepSeek mHC和FlashAttention 2——代表了AI训练与推理领域的两个重要方向:
- 架构创新:mHC通过流形约束解决超深网络的训练稳定性问题,为未来更大规模的模型奠定基础
- 系统优化:FlashAttention 2通过IO感知的算法设计和并行优化,在不牺牲精度的前提下大幅提升效率
这两项技术都体现了理论创新与工程实践的深度融合,是构建下一代AI系统的关键基石。

浙公网安备 33010602011771号