JoyBeanRobber

导航

mamba预训练权重本地加载代码

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer
import warnings

# 临时屏蔽 FutureWarning
warnings.simplefilter(action='ignore', category=FutureWarning)

# ===================== 核心配置(全部改为本地路径) =====================
# 本地Mamba2-780m权重文件夹路径
LOCAL_MAMBA_PATH = "mamba2-780m"  
# 本地gpt-neox-20b分词器文件夹路径
LOCAL_TOKENIZER_PATH = "gpt-neox-20b"  
# 设备配置
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 精度配置
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# ===================== 加载本地分词器 =====================
tokenizer = AutoTokenizer.from_pretrained(
    LOCAL_TOKENIZER_PATH,
    local_files_only=True,  # 本地加载
    padding_side="left"     # 左填充在初始化时设置
)
# 官方要求的pad_token配置
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# ===================== 加载本地Mamba2模型 =====================
model = MambaLMHeadModel.from_pretrained(
    LOCAL_MAMBA_PATH,
    device=DEVICE,
    dtype=DTYPE
)
model.eval()  # 推理模式

# ===================== 推理函数(对齐Mamba官方generate参数) =====================
def mamba_inference(prompt, max_new_tokens=100, temperature=1.0, top_k=1, top_p=0.0, min_p=0.0):
    """
    对齐Mamba官方generate方法的推理函数
    :param prompt: 输入提示文本
    :param max_new_tokens: 期望生成的新token数(转换为max_length传入)
    :param temperature: 生成随机性(Mamba默认1.0)
    :param top_k: 采样top_k(Mamba默认1)
    :param top_p: 采样top_p(Mamba默认0.0)
    :param min_p: 采样min_p(Mamba默认0.0)
    :return: 生成的文本
    """
    # 1. 编码输入文本
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048  # 输入最大长度限制
    ).to(DEVICE)
    
    # 2. 计算max_length(核心:Mamba需要总长度,而非新增长度)
    # max_length = 输入token长度 + 期望生成的新token长度
    input_length = inputs["input_ids"].shape[1]
    max_length = input_length + max_new_tokens
    # 防止总长度超过模型最大上下文(Mamba2-780m默认2048)
    max_length = min(max_length, 2048)

    # 3. 调用Mamba官方generate方法(严格对齐参数)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],  # 必填:输入token
            max_length=max_length,          # 必填:总生成长度(Mamba强制要求)
            temperature=temperature,        # 可选:随机性
            top_k=top_k,                    # 可选:top_k采样
            top_p=top_p,                    # 可选:top_p采样
            min_p=min_p,                    # 可选:min_p采样
        )

    # 4. 解码输出
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # 提取仅生成的部分(去除输入prompt)
    generated_text = response[len(prompt):].strip() if response.startswith(prompt) else response
    return generated_text

# ===================== 测试推理 =====================
if __name__ == "__main__":
    prompt = "请解释Mamba模型相比Transformer的核心优势是什么?"
    print(f"输入提示:{prompt}\n")
    
    try:
        # 注意:Mamba的generate方法默认top_k=1(确定性生成),如需随机性可调大temperature/top_k
        result = mamba_inference(
            prompt,
            max_new_tokens=200,
            temperature=0.7,  # 降低随机性
            top_k=50,         # 开启top_k采样
            top_p=0.9         # 开启top_p采样
        )
        print(f"生成结果:{result}")
    except Exception as e:
        print(f"推理过程出错:{str(e)}")
    
    # 打印调试信息
    print(f"当前设备:{DEVICE}")
    print(f"PyTorch版本:{torch.__version__}")
    print(f"模型设备:{next(model.parameters()).device}")

 

 

 

import warnings

# 临时屏蔽 FutureWarning
warnings.filterwarnings("ignore", category=FutureWarning)

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# ===================== 核心配置(全部改为本地路径) =====================
# 本地Mamba2-780m权重文件夹路径
LOCAL_MAMBA_PATH = "mamba2-780m"  # 替换为实际路径
# 本地gpt-neox-20b分词器文件夹路径
LOCAL_TOKENIZER_PATH = "gpt-neox-20b"  # 替换为实际路径
# 设备配置
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# 精度配置
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

COMPILE_MODEL = False  # 获取隐状态时关闭编译,避免输出被优化

# ===================== 加载本地分词器 =====================
tokenizer = AutoTokenizer.from_pretrained(
    LOCAL_TOKENIZER_PATH,
    local_files_only=True,  # 本地加载
    padding_side="left",    # 左填充在初始化时设置
)
# 官方要求的pad_token配置
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

# ===================== 加载本地Mamba2模型 =====================
model = MambaLMHeadModel.from_pretrained(
    LOCAL_MAMBA_PATH,
    device=DEVICE,
    dtype=DTYPE,
)

if COMPILE_MODEL and torch.cuda.is_available():
    model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

model.eval()  # 推理模式

# ===================== 获取SSM隐状态函数 =====================
@torch.no_grad()  # 彻底关闭梯度计算
@torch.cuda.amp.autocast(dtype=DTYPE)  # 混合精度加速
def get_mamba_ssm_hidden_state(prompt, max_new_tokens=1, return_last_token_only=True):
    """
    获取Mamba模型处理输入文本后的SSM隐状态
    :param prompt: 输入提示文本
    :param return_last_token_only: 是否只返回最后一个token的隐状态(默认True)
    :return: SSM隐状态张量,shape说明:
             - return_last_token_only=True: [hidden_dim] (如780M模型为2048)
             - return_last_token_only=False: [seq_len, hidden_dim]
    """
    # 1. 编码输入文本
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048  # 输入最大长度限制
    ).to(DEVICE)
    
    input_ids = inputs["input_ids"]
    
    # 2. 前向传播获取模型输出(包含隐状态)
    # MambaLMHeadModel的forward方法返回: (logits, hidden_states)
    hidden_states = model.backbone(
        input_ids=input_ids,
    )
    
    # 4. 处理输出格式
    # 去除batch维度(因为batch_size=1)
    print(hidden_states.shape)
    hidden_states = hidden_states.squeeze(0)  # shape: [seq_len, hidden_dim]
    
    if return_last_token_only:
        last_token_idx = hidden_states.shape[0] - 1
        
        # 提取最后一个token的隐状态
        last_token_hidden = hidden_states[last_token_idx]  # shape: [hidden_dim]
        return last_token_hidden
    else:
        # 返回整个序列的隐状态
        return hidden_states

# ===================== 测试获取隐状态 =====================
if __name__ == "__main__":
    prompt = "请解释Mamba模型相比Transformer的核心优势是什么?"
    print(f"输入提示:{prompt}\n")
    
    try:
        print("====== 获取SSM隐状态 ======")
        import time
        start_time = time.time()
        
        # 获取最后一个token的SSM隐状态
        last_hidden_state = get_mamba_ssm_hidden_state(prompt)
        
        end_time = time.time()
        print(f"获取隐状态耗时:{end_time-start_time:.2f}秒")
        print(f"最后一个token的SSM隐状态形状:{last_hidden_state.shape}")
        print(f"隐状态数据类型:{last_hidden_state.dtype}")
        print(f"隐状态设备:{last_hidden_state.device}")
        print(f"\n隐状态前10个值:{last_hidden_state[:10]}")
        
        # 可选:获取整个序列的隐状态
        # full_hidden_states = get_mamba_ssm_hidden_state(prompt, return_last_token_only=False)
        # print(f"\n整个序列的隐状态形状:{full_hidden_states.shape}")
        # print(f"序列长度(token数):{full_hidden_states.shape[0]}")
        # print(f"隐状态维度:{full_hidden_states.shape[1]}")
        
    except Exception as e:
        print(f"获取隐状态过程出错:{str(e)}")
        import traceback
        traceback.print_exc()
    
    # 打印调试信息
    print(f"\n当前设备:{DEVICE}")
    print(f"PyTorch版本:{torch.__version__}")
    print(f"模型设备:{next(model.parameters()).device}")

 

posted on 2026-03-08 19:00  欢乐豆掠夺者  阅读(3)  评论(0)    收藏  举报