Model Fine-tuning
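
The script below fine-tunes a causal language model on ShareGPT-style conversation data with LoRA adapters (via peft). Conversations are rebuilt into a "Human: ... / Assistant: ..." prompt, and the prompt tokens are masked out of the loss so training only targets the final assistant reply. It expects a local config.py that exposes data_path (a JSON/JSONL file of conversations) and model_path (the base checkpoint).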

import argparse
import os
from functools import partial

import torch
from config import data_path, model_path
from datasets import DatasetDict, load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)


def build_prompt_completion(example):
    conversations = example.get("conversations", [])
    if not isinstance(conversations, list) or len(conversations) < 2:
        return {"prompt": "", "completion": ""}

    prompt_parts = []
    for turn in conversations[:-1]:
        role = turn.get("from", "").lower()
        if role in {"human", "user"}:
            prompt_parts.append(f"Human: {turn.get('value', '').strip()}")
        else:
            prompt_parts.append(f"Assistant: {turn.get('value', '').strip()}")

    prompt = "\n".join(prompt_parts).strip()
    if prompt and not prompt.endswith("\n"):
        prompt += "\n"
    if prompt:
        prompt += "Assistant: "
    else:
        prompt = "Assistant: "

    final_turn = conversations[-1]
    completion = final_turn.get("value", "").strip()
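    # For example (illustrative record, not taken from the actual dataset):
    #   {"conversations": [{"from": "human", "value": "What is LoRA?"},
    #                      {"from": "gpt", "value": "LoRA adds low-rank adapters..."}]}
    # yields prompt = "Human: What is LoRA?\nAssistant: "
    # and completion = "LoRA adds low-rank adapters...".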
    return {"prompt": prompt, "completion": completion}


def tokenize_example(example, tokenizer, max_length):
    prompt_completion = build_prompt_completion(example)
    prompt = prompt_completion["prompt"]
    completion = prompt_completion["completion"]

    text = prompt + completion + (tokenizer.eos_token or "")
    tokenized = tokenizer(text, truncation=True, max_length=max_length, padding=False)
    prompt_tokens = tokenizer(
        prompt, truncation=True, max_length=max_length, add_special_tokens=False
    )

    # Mask the prompt tokens with -100 so loss is only computed on the completion.
    labels = tokenized["input_ids"].copy()
    prompt_len = min(len(prompt_tokens["input_ids"]), len(labels))
    if prompt_len > 0:
        labels[:prompt_len] = [-100] * prompt_len
    tokenized["labels"] = labels
    return tokenized


def main():
    parser = argparse.ArgumentParser(
        description="Fine-tune a causal LM on ShareGPT-style QA data with LoRA."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./fine_tuned_model",
        help="Where to save the fine-tuned model.",
    )
    parser.add_argument(
        "--num_train_epochs", type=int, default=3, help="Number of training epochs."
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size per device for training.",
    )
    parser.add_argument(
        "--learning_rate", type=float, default=2e-5, help="Learning rate."
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.01, help="Weight decay."
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=1024,
        help="Maximum sequence length for tokenization.",
    )
    parser.add_argument(
        "--eval_split_percentage",
        type=int,
        default=5,
        help="Percentage of data used for evaluation.",
    )
    parser.add_argument(
        "--save_steps", type=int, default=10000, help="Save checkpoint every X steps."
    )
    parser.add_argument(
        "--logging_steps", type=int, default=100, help="Log every X steps."
    )
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank.")
    parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha.")
    parser.add_argument(
        "--lora_dropout", type=float, default=0.05, help="LoRA dropout."
    )
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tokenizer.pad_token is None:
        # Reuse an existing special token for padding instead of adding a new one.
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.bos_token

    model = AutoModelForCausalLM.from_pretrained(model_path)
    # Only resize the embeddings if the tokenizer actually grew (no new tokens are added above).
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))

    # Find LoRA target modules dynamically. The "language_model" filter assumes the
    # causal LM sits under a `language_model` sub-module (e.g. a multimodal checkpoint);
    # drop that check for a plain text-only model.
    target_modules = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and "language_model" in name:
            if any(
                keyword in name
                for keyword in [
                    "q_proj",
                    "k_proj",
                    "v_proj",
                    "o_proj",
                    "gate_proj",
                    "up_proj",
                    "down_proj",
                ]
            ):
                target_modules.append(name)

    print(
        f"Found {len(target_modules)} target modules: {target_modules[:5]}..."
    )  # Show first 5

    # Configure LoRA
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        target_modules=target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
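    # Note: peft scales the LoRA update by lora_alpha / lora_r (2.0 with the defaults above).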

    # load_dataset("json", ...) returns a DatasetDict with a single "train" split.
    dataset = load_dataset("json", data_files=data_path)
    if "train" in dataset:
        dataset = dataset["train"]

    if args.eval_split_percentage > 0:
        dataset = dataset.train_test_split(
            test_size=args.eval_split_percentage / 100, seed=42
        )
    else:
        dataset = DatasetDict({"train": dataset})

    preprocess_fn = partial(
        tokenize_example, tokenizer=tokenizer, max_length=args.max_length
    )
    tokenized_dataset = dataset.map(
        preprocess_fn, remove_columns=dataset["train"].column_names
    )

    # Use DataCollatorForSeq2Seq so the prompt-masked labels are padded with -100;
    # DataCollatorForLanguageModeling would rebuild labels from input_ids and discard the masking.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_train_batch_size,
        eval_strategy="steps" if args.eval_split_percentage > 0 else "no",
        eval_steps=args.save_steps,
        save_steps=args.save_steps,
        logging_steps=args.logging_steps,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        report_to="none",
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset.get("test"),
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Merge the LoRA weights into the base model and save a standalone copy
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(os.path.join(args.output_dir, "merged"))
    tokenizer.save_pretrained(os.path.join(args.output_dir, "merged"))

    print(f"LoRA adapters saved to {args.output_dir}")
    print(f"Merged model saved to {os.path.join(args.output_dir, 'merged')}")


if __name__ == "__main__":
    main()
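
To launch training with the defaults, something like the following should work (the filename finetune_lora.py is an assumption; use whatever name the script is saved under):

python finetune_lora.py --output_dir ./fine_tuned_model --num_train_epochs 3

After training, a minimal inference sketch against the merged checkpoint (the path assumes the default --output_dir above):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "./fine_tuned_model/merged"  # written by the script's merge step
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

# Use the same "Human: ...\nAssistant: " format the training prompts were built with.
prompt = "Human: What is LoRA?\nAssistant: "
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))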

