# Use the local Baichuan model here (avoids downloading from Hugging Face Hub)
from airllm import AirLLMLlama2,AutoModel
# Maximum input length (in tokens) for tokenizer truncation.
MAX_LENGTH = 128

# Downloading from the Hugging Face Hub by default is slow (and may require a
# proxy), so load from a local snapshot instead:
# model = AutoModel.from_pretrained("baichuan-inc/Baichuan2-7B-Base", profiling_mode=True)

# Local model snapshot path.
# NOTE: in a raw string backslashes are literal, so single backslashes are
# correct here; the original r'D:\\cache\\...' put doubled backslashes into
# the actual path string.
model = AirLLMLlama2(r'D:\cache\hub\Baichuan2-7B\snapshots\3db3da5')

# Prompt(s) to continue.
input_text = [
    # 'What is the capital of China?',
    'I like',
]

# Tokenize; no attention mask / padding needed for a single unbatched prompt.
input_tokens = model.tokenizer(
    input_text,
    return_tensors="pt",
    return_attention_mask=False,
    truncation=True,
    max_length=MAX_LENGTH,
    # padding=True
)

# Generate a short continuation on the GPU (airllm streams layers from disk,
# so this works even when the full model does not fit in VRAM).
generation_output = model.generate(
    input_tokens['input_ids'].cuda(),
    max_new_tokens=3,
    use_cache=True,
    return_dict_in_generate=True)

# Decode and actually show the result — the original computed the decoded
# string but discarded it.
print(model.tokenizer.decode(generation_output.sequences[0]))