自用共享文档
# Mamba 模块测试
# Smoke test for the mamba_ssm ``Mamba`` block: run one forward pass on random
# input and check that the output has the same shape as the input.
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 196, 64

# Both the input and the model are placed on "cuda" below; fail early with a
# readable message instead of an opaque device error when no GPU is visible.
if not torch.cuda.is_available():
    raise RuntimeError("This Mamba test requires a CUDA-capable GPU.")

# Allocate directly on the GPU rather than creating on CPU and copying over.
x = torch.randn(batch, length, dim, device="cuda")

# Parameter meanings as described in the mamba_ssm README.
model = Mamba(
    d_model=dim,  # model dimension; matches the last axis of x
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # local convolution width
    expand=2,     # block expansion factor
).to("cuda")

y = model(x)
print(y.shape)

# The check performed by this script: output shape equals input shape.
assert y.shape == x.shape, (
    f"expected output shape {tuple(x.shape)}, got {tuple(y.shape)}"
)
安装命令：pip install mamba-ssm --no-cache-dir --verbose

浙公网安备 33010602011771号