自用共享文档

# mamba 模块测试 / Mamba module smoke test
import torch
from mamba_ssm import Mamba


def main(batch: int = 2, length: int = 196, dim: int = 64, device: str = "cuda") -> torch.Tensor:
    """Smoke-test the mamba_ssm ``Mamba`` block with a single forward pass.

    Builds a random input tensor of shape ``(batch, length, dim)`` and runs it
    through a freshly initialised ``Mamba`` layer on ``device``. It prints the
    output shape and checks that it equals the input shape.

    Args:
        batch: Number of sequences in the random input batch.
        length: Sequence length (second axis of the input).
        dim: Size of the last input axis; also passed to ``Mamba`` as ``d_model``.
        device: Torch device to run on. Defaults to ``"cuda"``, as in the
            original snippet.

    Returns:
        The output tensor produced by the model.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is not available,
            or if the output shape differs from the input shape.
    """
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        # Fail fast with a readable message instead of a low-level CUDA error.
        raise RuntimeError("CUDA is not available; this smoke test was asked to run on a GPU.")

    # Allocate directly on the target device rather than creating on CPU and copying.
    x = torch.randn(batch, length, dim, device=device)
    model = Mamba(
        d_model=dim,  # model dimension; set equal to x's last axis
        d_state=16,   # SSM state expansion factor (per the mamba_ssm README)
        d_conv=4,     # local convolution width (per the mamba_ssm README)
        expand=2,     # block expansion factor (per the mamba_ssm README)
    ).to(device)

    y = model(x)
    print(y.shape)

    # Explicit check instead of `assert`, which `python -O` would strip out.
    if y.shape != x.shape:
        raise RuntimeError(
            f"Mamba output shape {tuple(y.shape)} does not match input shape {tuple(x.shape)}"
        )
    return y


if __name__ == "__main__":
    main()

安装 / Install (run in a shell, not inside Python):
pip install mamba-ssm --no-cache-dir --verbose

posted @ 2026-01-21 16:28  小花护符  阅读(3)  评论(0)    收藏  举报