mamba改进
NaN 爆炸几乎总是源于循环核心内部的数值溢出。
我们插入的代码片段请求 GRU 风格的 Mamba 使用 4096 维的隐藏状态(d_model = H × W = 64 × 64)。
这非常大:循环矩阵的大小为 4096×4096,每次迭代会多次进行乘法运算,所以即使是默认的 Xavier/正交初始化,激活值也可能在几次迭代后发散,最终得到 inf/NaN。
以下是三种经过验证的解决方案。根据您的内存/速度预算选择最适合的一种;只需在不变的 VAE 循环(z_list … 及之后部分保持不变)之前添加即可。
选项 A(推荐) 投影降维 → Mamba → 投影升维
# ------------- 每个时间步的跨通道混合 -------------
B, T, C, H, W = inputs['atmos'].shape          # C = 9,H=W=64
HW      = H * W                                # 4096
D_mamba = self.embed_dim                       # 例如 128(更安全)
# (B*T, 9, HW) — 展平空间网格
tokens   = inputs['atmos'].reshape(B * T, C, HW)
# ① 将 4096 维向量投影到 128
tokens   = self.down_proj(tokens)              # (B*T, 9, D_mamba)
# ② 通道-wise Mamba
tokens   = self.channel_mamba(tokens)          # (B*T, 9, D_mamba)
# ③ 将维度投影回 4096,以便 VAE 见到相同的形状
tokens   = self.up_proj(tokens)                # (B*T, 9, HW)
# 重塑为原始的 4D 张量
atmos    = tokens.view(B * T, C, H, W)         # <- 与之前相同
在 __init__ 中一次添加小型层:
self.down_proj = nn.Linear(H*W, embed_dim, bias=False)
self.channel_mamba = SimpleMambaBlock(d_model=embed_dim)
self.up_proj   = nn.Linear(embed_dim, H*W, bias=False)
循环核心现在以 d_model = 128 运行,因此它更稳定且在 FLOPs 和内存方面便宜 32 倍。
选项 B 层归一化 → Mamba → 层归一化(原地稳定化)
保持 d_model = 4096,但用归一化包装输入/输出,这样 GRU 永远不会看到无界值:
normed  = self.pre_norm(atmos_flat)            # 对最后一个维度进行层归一化
mixed   = self.channel_mamba(normed)           # 仍然是 4096 维
atmos   = self.post_norm(mixed).view(B*T, C, H, W)
其中
self.pre_norm  = nn.LayerNorm(H*W)
self.post_norm = nn.LayerNorm(H*W)
这是最轻量级的代码更改,但您仍然要为 4096 维付出内存代价,且 GRU 可能仍然很慢。
选项 C 将每个像素视为时间轴(d_model = 9)
如果您真正的目标只是让九个大气通道相互作用,而实际上不需要 4096 维的隐藏状态,则可以交换序列/特征:
# (B*T, HW, 9) — 序列长度 = 4096,特征维度 = 9
tokens  = inputs['atmos'].reshape(B*T, C, HW).transpose(1, 2)
# GRU,`d_model = 9`   (在 `__init__` 中定义一次)
tokens  = self.pixel_mamba(tokens)             # (B*T, HW, 9)
# 返回到 (B*T, 9, H, W)
atmos   = tokens.transpose(1, 2).view(B*T, C, H, W)
这将重维度推到序列长度,因此 d_model = 9 非常稳定。FLOPs 随 4096 步线性增长,但每一步都很小;总体而言,它仍然比 4096 维的 GRU 更快。
避免 NaN 的快速检查清单,无论您选择哪种方法
- 初始化 — 使用正交或 Xavier 均匀分布初始化 GRU 权重;PyTorch 的默认值是好的,但避免使用跳过初始化的自定义种子。
 - 输入范围 — 您已经将每个通道归一化到 
[0, 1];这是好的。如果您采用选项 A/B,减去 0.5 以使其以 0 为中心。 - 混合精度 — 如果您使用 AMP,请将 Mamba 块保持在 
float32(with torch.cuda.amp.autocast(enabled=False): …)以获得额外的安全性。 - 梯度裁剪 — 在 
loss.backward()后使用torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)永远不会出错。 
实现选项 A 或 B,NaN 应该会消失,同时代码 之后 您的注释块(z_list … 部分)保持 100% 不变。
                    
                
                
            
        
浙公网安备 33010602011771号