import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer
import warnings
# 临时屏蔽 FutureWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
# ===================== 核心配置(全部改为本地路径) =====================
# 本地Mamba2-780m权重文件夹路径
LOCAL_MAMBA_PATH = "mamba2-780m"
# 本地gpt-neox-20b分词器文件夹路径
LOCAL_TOKENIZER_PATH = "gpt-neox-20b"
# 设备配置
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 精度配置
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
# ===================== 加载本地分词器 =====================
tokenizer = AutoTokenizer.from_pretrained(
LOCAL_TOKENIZER_PATH,
local_files_only=True, # 本地加载
padding_side="left" # 左填充在初始化时设置
)
# 官方要求的pad_token配置
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
# ===================== 加载本地Mamba2模型 =====================
model = MambaLMHeadModel.from_pretrained(
LOCAL_MAMBA_PATH,
device=DEVICE,
dtype=DTYPE
)
model.eval() # 推理模式
# ===================== 推理函数(对齐Mamba官方generate参数) =====================
def mamba_inference(prompt, max_new_tokens=100, temperature=1.0, top_k=1, top_p=0.0, min_p=0.0):
"""
对齐Mamba官方generate方法的推理函数
:param prompt: 输入提示文本
:param max_new_tokens: 期望生成的新token数(转换为max_length传入)
:param temperature: 生成随机性(Mamba默认1.0)
:param top_k: 采样top_k(Mamba默认1)
:param top_p: 采样top_p(Mamba默认0.0)
:param min_p: 采样min_p(Mamba默认0.0)
:return: 生成的文本
"""
# 1. 编码输入文本
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048 # 输入最大长度限制
).to(DEVICE)
# 2. 计算max_length(核心:Mamba需要总长度,而非新增长度)
# max_length = 输入token长度 + 期望生成的新token长度
input_length = inputs["input_ids"].shape[1]
max_length = input_length + max_new_tokens
# 防止总长度超过模型最大上下文(Mamba2-780m默认2048)
max_length = min(max_length, 2048)
# 3. 调用Mamba官方generate方法(严格对齐参数)
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"], # 必填:输入token
max_length=max_length, # 必填:总生成长度(Mamba强制要求)
temperature=temperature, # 可选:随机性
top_k=top_k, # 可选:top_k采样
top_p=top_p, # 可选:top_p采样
min_p=min_p, # 可选:min_p采样
)
# 4. 解码输出
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# 提取仅生成的部分(去除输入prompt)
generated_text = response[len(prompt):].strip() if response.startswith(prompt) else response
return generated_text
# ===================== 测试推理 =====================
if __name__ == "__main__":
prompt = "请解释Mamba模型相比Transformer的核心优势是什么?"
print(f"输入提示:{prompt}\n")
try:
# 注意:Mamba的generate方法默认top_k=1(确定性生成),如需随机性可调大temperature/top_k
result = mamba_inference(
prompt,
max_new_tokens=200,
temperature=0.7, # 降低随机性
top_k=50, # 开启top_k采样
top_p=0.9 # 开启top_p采样
)
print(f"生成结果:{result}")
except Exception as e:
print(f"推理过程出错:{str(e)}")
# 打印调试信息
print(f"当前设备:{DEVICE}")
print(f"PyTorch版本:{torch.__version__}")
print(f"模型设备:{next(model.parameters()).device}")
import warnings
# 临时屏蔽 FutureWarning
warnings.filterwarnings("ignore", category=FutureWarning)
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer
# ===================== 核心配置(全部改为本地路径) =====================
# 本地Mamba2-780m权重文件夹路径
LOCAL_MAMBA_PATH = "mamba2-780m" # 替换为实际路径
# 本地gpt-neox-20b分词器文件夹路径
LOCAL_TOKENIZER_PATH = "gpt-neox-20b" # 替换为实际路径
# 设备配置
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 精度配置
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
COMPILE_MODEL = False # 获取隐状态时关闭编译,避免输出被优化
# ===================== 加载本地分词器 =====================
tokenizer = AutoTokenizer.from_pretrained(
LOCAL_TOKENIZER_PATH,
local_files_only=True, # 本地加载
padding_side="left", # 左填充在初始化时设置
)
# 官方要求的pad_token配置
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
# ===================== 加载本地Mamba2模型 =====================
model = MambaLMHeadModel.from_pretrained(
LOCAL_MAMBA_PATH,
device=DEVICE,
dtype=DTYPE,
)
if COMPILE_MODEL and torch.cuda.is_available():
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
model.eval() # 推理模式
# ===================== 获取SSM隐状态函数 =====================
@torch.no_grad() # 彻底关闭梯度计算
@torch.cuda.amp.autocast(dtype=DTYPE) # 混合精度加速
def get_mamba_ssm_hidden_state(prompt, max_new_tokens=1, return_last_token_only=True):
"""
获取Mamba模型处理输入文本后的SSM隐状态
:param prompt: 输入提示文本
:param return_last_token_only: 是否只返回最后一个token的隐状态(默认True)
:return: SSM隐状态张量,shape说明:
- return_last_token_only=True: [hidden_dim] (如780M模型为2048)
- return_last_token_only=False: [seq_len, hidden_dim]
"""
# 1. 编码输入文本
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048 # 输入最大长度限制
).to(DEVICE)
input_ids = inputs["input_ids"]
# 2. 前向传播获取模型输出(包含隐状态)
# MambaLMHeadModel的forward方法返回: (logits, hidden_states)
hidden_states = model.backbone(
input_ids=input_ids,
)
# 4. 处理输出格式
# 去除batch维度(因为batch_size=1)
print(hidden_states.shape)
hidden_states = hidden_states.squeeze(0) # shape: [seq_len, hidden_dim]
if return_last_token_only:
last_token_idx = hidden_states.shape[0] - 1
# 提取最后一个token的隐状态
last_token_hidden = hidden_states[last_token_idx] # shape: [hidden_dim]
return last_token_hidden
else:
# 返回整个序列的隐状态
return hidden_states
# ===================== 测试获取隐状态 =====================
if __name__ == "__main__":
prompt = "请解释Mamba模型相比Transformer的核心优势是什么?"
print(f"输入提示:{prompt}\n")
try:
print("====== 获取SSM隐状态 ======")
import time
start_time = time.time()
# 获取最后一个token的SSM隐状态
last_hidden_state = get_mamba_ssm_hidden_state(prompt)
end_time = time.time()
print(f"获取隐状态耗时:{end_time-start_time:.2f}秒")
print(f"最后一个token的SSM隐状态形状:{last_hidden_state.shape}")
print(f"隐状态数据类型:{last_hidden_state.dtype}")
print(f"隐状态设备:{last_hidden_state.device}")
print(f"\n隐状态前10个值:{last_hidden_state[:10]}")
# 可选:获取整个序列的隐状态
# full_hidden_states = get_mamba_ssm_hidden_state(prompt, return_last_token_only=False)
# print(f"\n整个序列的隐状态形状:{full_hidden_states.shape}")
# print(f"序列长度(token数):{full_hidden_states.shape[0]}")
# print(f"隐状态维度:{full_hidden_states.shape[1]}")
except Exception as e:
print(f"获取隐状态过程出错:{str(e)}")
import traceback
traceback.print_exc()
# 打印调试信息
print(f"\n当前设备:{DEVICE}")
print(f"PyTorch版本:{torch.__version__}")
print(f"模型设备:{next(model.parameters()).device}")