import argparse
import os
from functools import partial

import torch
from datasets import DatasetDict, load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)

from config import data_path, model_path
def build_prompt_completion(example):
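    """Join all turns except the last into a "Human:"/"Assistant:" prompt and
    return the final turn's text as the completion."""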
conversations = example.get("conversations", [])
if not isinstance(conversations, list) or len(conversations) < 2:
return {"prompt": "", "completion": ""}
prompt_parts = []
for turn in conversations[:-1]:
role = turn.get("from", "").lower()
if role in {"human", "user"}:
prompt_parts.append(f"Human: {turn.get('value', '').strip()}")
else:
prompt_parts.append(f"Assistant: {turn.get('value', '').strip()}")
prompt = "\n".join(prompt_parts).strip()
if prompt and not prompt.endswith("\n"):
prompt += "\n"
if prompt:
prompt += "Assistant: "
else:
prompt = "Assistant: "
final_turn = conversations[-1]
completion = final_turn.get("value", "").strip()
return {"prompt": prompt, "completion": completion}
def tokenize_example(example, tokenizer, max_length):
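    """Tokenize prompt + completion and mask the prompt positions in labels with
    -100 so the loss is computed only on the completion tokens."""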
prompt_completion = build_prompt_completion(example)
prompt = prompt_completion["prompt"]
completion = prompt_completion["completion"]
if not completion:
completion = ""
text = prompt + completion + (tokenizer.eos_token or "")
tokenized = tokenizer(text, truncation=True, max_length=max_length, padding=False)
    # Tokenize the prompt with the same settings as the full text so its token ids
    # form a prefix of tokenized["input_ids"] (this assumes the tokenizer only
    # prepends special tokens such as BOS; adjust if it also appends tokens).
    prompt_tokens = tokenizer(prompt, truncation=True, max_length=max_length, padding=False)
    labels = tokenized["input_ids"].copy()
    # Mask the prompt positions so the loss is computed only on the completion.
    prompt_len = min(len(prompt_tokens["input_ids"]), len(labels))
    if prompt_len > 0:
        labels[:prompt_len] = [-100] * prompt_len
tokenized["labels"] = labels
return tokenized
def main():
parser = argparse.ArgumentParser(
description="Fine-tune a causal LM on ShareGPT-style QA data with LoRA."
)
parser.add_argument(
"--output_dir",
type=str,
default="./fine_tuned_model",
help="Where to save the fine-tuned model.",
)
parser.add_argument(
"--num_train_epochs", type=int, default=3, help="Number of training epochs."
)
parser.add_argument(
"--per_device_train_batch_size",
type=int,
default=1,
help="Batch size per device for training.",
)
parser.add_argument(
"--learning_rate", type=float, default=2e-5, help="Learning rate."
)
parser.add_argument(
"--weight_decay", type=float, default=0.01, help="Weight decay."
)
parser.add_argument(
"--max_length",
type=int,
default=1024,
help="Maximum sequence length for tokenization.",
)
parser.add_argument(
"--eval_split_percentage",
type=int,
default=5,
help="Percentage of data used for evaluation.",
)
parser.add_argument(
"--save_steps", type=int, default=10000, help="Save checkpoint every X steps."
)
parser.add_argument(
"--logging_steps", type=int, default=100, help="Log every X steps."
)
parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank.")
parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha.")
parser.add_argument(
"--lora_dropout", type=float, default=0.05, help="LoRA dropout."
)
args = parser.parse_args()
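    # Load the tokenizer and base model from the local checkpoint path in config.py.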
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tokenizer.pad_token is None:
        # Reuse the EOS token for padding; only add a dedicated pad token if the
        # tokenizer has neither.
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "<pad>"})
model = AutoModelForCausalLM.from_pretrained(model_path)
model.resize_token_embeddings(len(tokenizer))
    # Find LoRA target modules dynamically: pick the linear projection layers of the
    # decoder, which is assumed to live under a "language_model" prefix (as in
    # multimodal checkpoints).
target_modules = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "language_model" in name:
if any(
keyword in name
for keyword in [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
):
target_modules.append(name)
print(
f"Found {len(target_modules)} target modules: {target_modules[:5]}..."
) # Show first 5
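    if not target_modules:
        # Fallback assumption: a text-only checkpoint has no "language_model" prefix,
        # so target the projection layers by suffix, which PEFT also accepts.
        target_modules = [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]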
# Configure LoRA
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=target_modules,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
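    # load_dataset("json", ...) returns a DatasetDict with a single "train" split;
    # pull it out and optionally carve off an evaluation split.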
    dataset = load_dataset("json", data_files=data_path)["train"]
    if args.eval_split_percentage > 0:
        dataset = dataset.train_test_split(
            test_size=args.eval_split_percentage / 100, seed=42
        )
    else:
        dataset = DatasetDict({"train": dataset})
preprocess_fn = partial(
tokenize_example, tokenizer=tokenizer, max_length=args.max_length
)
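    # Tokenize every split and drop the raw columns so only model inputs remain.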
tokenized_dataset = dataset.map(
preprocess_fn, remove_columns=dataset["train"].column_names
)
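    # DataCollatorForSeq2Seq pads both input_ids and the precomputed labels
    # (keeping the -100 prompt mask). DataCollatorForLanguageModeling would
    # rebuild labels from input_ids and discard that masking.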
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=-100)
training_args = TrainingArguments(
output_dir=args.output_dir,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_train_batch_size,
eval_strategy="steps" if args.eval_split_percentage > 0 else "no",
eval_steps=args.save_steps,
save_steps=args.save_steps,
logging_steps=args.logging_steps,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
save_total_limit=2,
fp16=torch.cuda.is_available(),
report_to="none",
push_to_hub=False,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset.get("test"),
processing_class=tokenizer,
data_collator=data_collator,
)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
    # Merge the LoRA weights into the base model and save a standalone copy.
merged_model = model.merge_and_unload()
merged_model.save_pretrained(os.path.join(args.output_dir, "merged"))
tokenizer.save_pretrained(os.path.join(args.output_dir, "merged"))
print(f"LoRA adapters saved to {args.output_dir}")
print(f"Merged model saved to {os.path.join(args.output_dir, 'merged')}")
if __name__ == "__main__":
main()