尝试上手大模型的有监督微调

总览

这个文章留下微调 Gemma-2b-it 模型的记录。以很简单的、只有一句话的微调为例。

本文最终的目标:问模型 “微调测试”,模型回答 “我学会了”。

准备

加载模型和分词器

tokenizer = AutoTokenizer.from_pretrained(
    "google/gemma-2b-it",
    cache_dir="/mnt/g/models/cache",
    token=token,
)

model = AutoModelForCausalLM.from_pretrained(
    "/mnt/g/models/gemma-1.1-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

训练集

对话训练集格式通常是 json,网上下载的后缀一般为 .jsonl,长这样(来源):

{"conversation": [{"role": "human", "message": "中国移动营销行来发展报告alink"}, {"role": "assistant", "message": "王大姐,打字细心一点"}], "category": null, "data_source": "weibo"}
{"conversation": [{"role": "human", "message": "小马也疯狂------地位之争。"}, {"role": "assistant", "message": "于老师不给劝劝架么告诉他们再挣也不是老大"}], "category": null, "data_source": "weibo"}
{"conversation": [{"role": "human", "message": "那些年,我们一起偷看过的电视。「暴走漫画」"}, {"role": "assistant", "message": "真不愧是这么走出来的爹·······"}], "category": null, "data_source": "weibo"}
···

使用 HuggingFace 的 dataset 库可加载这个文件。

from datasets import load_dataset

dataset_fromfile = load_dataset(
    "json",
    data_files=r"./test_input_dataset.jsonl",
    split="train",
)

本文只用一个对话的微调进行示例,所以用以下方法准备训练集 chat

chat = {
    "conversation": [
        [
            {
                "role": "user",
                "content": "微调测试",
            },
            {
                "role": "assistant",
                "content": "我学会了",
            },
        ],
    ]
}
chat = Dataset.from_dict(chat)

TRL 库

借助 HuggingFace 的 TRL 库进行微调。

pip install trl

接下来介绍 TRL 的两个类,以及一个特殊的函数 formatting_func

SFTTrainer

SFT 是 “有监督微调” 的缩写(Supervised Finetuning)。

SFTTrainer 继承于 transformers.Trainer。借助 SFTTrainer,可以封装一个专用于语言模型有监督微调的类。

DataCollatorForCompletionOnlyLM

借助 DataCollatorForCompletionOnlyLM,可以仅对需要生成的 prompt 训练。即,只对模型生成的 token 部分计算 loss。

其他细节不必深究,只需要知道 SFTTrainer 需要一个 data_collator 对象,将语料转换成适合训练的形式。

response_template = "<start_of_turn>model\n"

collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tokenizer,
    response_template=response_template,
)

可见,实例化这个 collator 需要传入 tokenizerresponse_template

在 Gemma 中,模型的回答都接在 "<start_of_turn>model\n" 之后,所以传入这个 response_template 告诉 collator 从这里开始标记需要训练的部分。

formatting_func

语料需要先转换成某种字符串,再转换成 token,才能输入到模型。

为了将训练语料正确处理成符合预训练模型规则的字符串,SFTTrainer 需要传入一个处理函数。

def formatting_prompts_func(example):
    output_texts = []
    for c in example["conversation"]:
        text = tokenizer.apply_chat_template(c, tokenize=False) + tokenizer.eos_token
        output_texts.append(text)
    return output_texts

这里取了巧,借助 tokenizer 自带的 chat_template 转换。

TrainingArguments

需要向 SFTTrainer 传入优化器、学习率等参数。

不必多言,看示例代码。更多可选参数请查阅 HuggingFace 文档。

from transformers import TrainingArguments

args = TrainingArguments(
        per_device_train_batch_size=8,
        num_train_epochs=30,
        learning_rate=2e-5,
        optim="adamw_8bit",
        bf16=True,
        output_dir="/mnt/z/model_test",
        report_to=["tensorboard"],
        logging_steps=1,
)

开始训练

做好一切准备后,就能实例化 SFTrainer 开始训练了。

trainer = SFTTrainer(
    model,
    tokenizer=tokenizer,
    train_dataset=chat,
    max_seq_length=1024,
    args=args,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
    dataset_kwargs={"add_special_tokens": False},  # 特殊 token 已经在 formatting_func 加过了
)
trainer.train()

LoRA

借助 peft 库,只需要封装一遍 model 就能应用 LoRA。

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=512,
    lora_alpha=512,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
    ],
)

model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()

接下来向 SFTTrainer 传入这个 model 就行。

测试

我用这段代码测试训练效果:

chat = [
    {
        "role": "user",
        "content": "微调测试",
    },
]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_length=100)
print(tokenizer.decode(outputs[0]))

可以看到效果很明显。

<bos><start_of_turn>user
微调测试<end_of_turn>
<start_of_turn>model
我学会了<end_of_turn>
<eos>
posted @ 2024-04-27 00:44  倒地  阅读(57)  评论(0编辑  收藏  举报