如何使用 Hugging Face TRL 调优多模态模型或视觉语言模型
多模态LLM最近取得了巨大的进展。我们现在已经有一个强大的开放多模态模型生态系统,主要是视觉-语言模型(VLM),包括Meta AI的Llama-3.2-11B-Vision,Mistral AI的Pixtral-12B,Qwen的Qwen2-VL-7B,以及Allen AI的Molmo-7B-D-0924。
这些VLMs可以处理各种多模态任务,包括图像字幕、视觉问答和图像-文本匹配,无需额外训练。然而,为了将模型定制化以适应您的特定应用,您可能需要在您的数据上微调模型,以实现更高质量的结果或为您的用例创建一个更高效的模型。
本文博客将指导您如何使用Hugging Face TRL、Transformers和datasets在2024年微调开放VLMs。
- 定义我们的多模式用例
- 设置开发环境
- 创建和准备多模态数据集
- 使用
trl微调VLM,并且SFTTrainer - 测试和评估VLM
注意:此博客是为消费者级别的GPU(24GB)创建的,例如NVIDIA A10G或RTX 4090/3090,但可以轻松改编以在更大的GPU上运行。
1. 定义我们的多模态使用案例
在微调VLMs时,明确定义您的使用案例和您希望解决的多模态任务至关重要。这将指导您选择基础模型,并帮助您创建适当的微调数据集。如果您尚未定义使用案例,您可能需要重新审视您的需求。
值得注意的是,对于大多数使用场景,微调可能并不是首选方案。我们建议在决定微调自己的模型之前,先评估预训练模型或基于API的解决方案。
作为一个例子,我们将使用以下多模态用例:
我们希望微调一个模型,该模型可以根据产品图片和基本元数据生成详细的产品描述。该模型将被整合到我们的电子商务平台中,以帮助卖家创建更具吸引力的列表。目标是减少创建产品描述所需的时间,并提高其质量和一致性。
现有的模型可能已经非常适合这个用例,但您可能需要根据您的具体需求进行调整/优化。这个图像到文本生成任务非常适合微调VLMs,因为它需要理解视觉特征,并将它们与文本信息结合起来,以生成连贯和相关的描述。我使用Gemini 1.5为这个用例创建了一个测试数据集philschmid/amazon-product-descriptions-vlm.
2. 设置开发环境
我们的第一步是安装 Hugging Face 库和 Pyroch,包括trl、transformers 和 datasets。如果你还没有听说过trl,别担心。trl 是一个基于 transformers 和 datasets 的库,它使微调、rlhf、对齐开放 LLM 更加容易。
# Install Pytorch & other libraries %pip install "torch==2.4.0" tensorboard pillow # Install Hugging Face libraries %pip install --upgrade \ "transformers==4.45.1" \ "datasets==3.0.1" \ "accelerate==0.34.2" \ "evaluate==0.4.3" \ "bitsandbytes==0.44.0" \ "trl==0.11.1" \ "peft==0.13.0" \ "qwen-vl-utils"
3. 创建和准备数据集
一旦你确定微调是正确的解决方案,我们需要创建一个数据集来微调我们的模型。我们必须将数据集准备成模型可以理解的格式。
在我们的例子中,我们将使用philschmid/amazon-product-descriptions-vlm,它包含1350个亚马逊产品,带有标题、图片和描述以及元数据。我们希望将我们的模型微调,以根据图片、标题和元数据生成产品描述。因此,我们需要创建一个包含标题、元数据和图片的提示,而完成内容是描述。
TRL 支持流行的指令和对话数据集格式。这意味着我们只需将数据集转换为支持的格式之一,trl其余的我们会处理。
"messages": [
{"role": "system", "content": [{"type":"text", "text": "You are a helpful...."}]},
{"role": "user", "content": [{
"type": "text", "text": "How many dogs are in the image?",
"type": "image", "text": <PIL.Image>
}]},
{"role": "assistant", "content": [{"type":"text", "text": "There are 3 dogs in the image."}]}
],
在我们的示例中,我们将使用Datasets库加载我们的数据集,并将其从pt格式转换为对话格式。
让我们从定义我们的指令提示开始。
from datasets import load_dataset # 请注意,prompt中没有提供图像,它作为"processor"的一部分包含在内。 prompt= """ 根据提供的##PRODUCT NAME##、##CATEGORY## 和图片创建简短的产品描述。 仅返回描述内容。该描述应针对搜索引擎优化(SEO),并为移动搜索提供更好的体验。 ## 产品名称 ##: {product_name} ## 类别 ##: {category} """ system_message = "你是一个专业的亚马逊产品描述作者。" # Convert dataset to OAI messages def format_data(sample): return {"messages": [ { "role": "system", "content": [{"type": "text", "text": system_message}], }, { "role": "user", "content": [ { "type": "text", "text": prompt.format(product_name=sample["Product Name"], category=sample["Category"]), },{ "type": "image", "image": sample["image"], } ], }, { "role": "assistant", "content": [{"type": "text", "text": sample["description"]}], }, ], } # Load dataset from the hub dataset_id = "philschmid/amazon-product-descriptions-vlm" dataset = load_dataset(dataset_id, split="train") print(dataset[0]) # 展示图片 from PIL import Image from IPython.display import display display(dataset[0]['image']) # Convert dataset to OAI messages # 需要使用列表推导来保持 Pil. Image 类型,.mape 将图像转换为字节 dataset = [format_data(sample) for sample in dataset] print(dataset[345]["messages"])

{
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x7FBAFB671C30>,
'Uniq Id': '002e4642d3ead5ecdc9958ce0b3a5a79',
'Product Name': 'Kurio Glow Smartwatch for Kids with Bluetooth, Apps, Camera & Games, Blue',
'Category': "Toys & Games | Kids' Electronics | Electronic Learning Toys",
'Selling Price': '$31.30',
'Model Number': 'C17515',
'About Product': 'Make sure this fits by entering your model number. | Kurio watch glow is a real Bluetooth Smartwatch built especially for kids, packed with 20+ apps & games! | Get your glow on with new light-up feature that turns games and activities into colorful fun. | Kurio watch glow includes brand-new games with light effects, including the My little dragon virtual pet and color-changing mood sensor. | Play single and two-player games on one watch, Or connect two watches together via Bluetooth, plus motion-sensitive games that get kids moving! | Take fun selfies with the front-facing camera and decorate them with filters, frames and stickers. | Plus, everything you need in a smartwatch – activity tracker, music player, Alarm/stopwatch, calculator, calendar and so much more! | Scratch resistant and splash-proof - suitable for kids ages 4 and up!',
'Product Specification': 'ProductDimensions:5x3x12inches|ItemWeight:7.2ounces|ShippingWeight:7.2ounces(Viewshippingratesandpolicies)|ASIN:B07TFD5D55|Itemmodelnumber:C17515|Manufacturerrecommendedage:4yearsandup|Batteries:1LithiumPolymerbatteriesrequired.(included)',
'Technical Details': "Color:Blue show up to 2 reviews by default This sleek, hi-tech Bluetooth Smartwatch is made specifically for kids, and it's packed with apps and games for out-of-the-box fun! Take selfies and videos, play single and two-player games, message friends, listen to music, plus everything you need in a smartwatch– activity tracker, alarm/stopwatch, calculator, calendar and so much more! Plus, parents can add vital information like blood type and allergies to an 'in case of an emergency' (I. C. E. ) app | 7.2 ounces (View shipping rates and policies)",
'Shipping Weight': '7.2 ounces',
'Variants': 'https://www.amazon.com/Kurio-Smartwatch-Bluetooth-Camera-Games/dp/B07TFD5D55|https://www.amazon.com/Kurio-Smartwatch-Bluetooth-Camera-Games/dp/B07TD8JHKW',
'Product Url': 'https://www.amazon.com/Kurio-Smartwatch-Bluetooth-Camera-Games/dp/B07TFD5D55',
'Is Amazon Seller': 'Y',
'description': "Kurio Glow Smartwatch: Fun, Safe & Educational! This kids' smartwatch boasts Bluetooth connectivity, built-in apps & games, and a camera – all in a vibrant blue design. Perfect for learning & play! #kidssmartwatch #kidselectronics #educationaltoys #kurioglow"
}
[
{
'role': 'system',
'content': [{
'type': 'text',
'text': '你是一个专业的亚马逊产品描述作者。'}]
},
{
'role': 'user',
'content': [
{
'type': 'text',
'text': '根据提供的##PRODUCT NAME##、##CATEGORY## 和图片创建简短的产品描述。仅返回描述内容。该描述应针对搜索引擎优化(SEO),并为移动搜索提供更好的体验。\n## 产品名称 ##: MasterPieces Tribal Spirit Jigsaw Puzzle, The Chiefs, Featuring American Indian Tribe Traditions & Ceremonies, 1000 Pieces\n## 类别 ##: Toys & Games | Puzzles | Jigsaw Puzzles'
},
{
'type': 'image',
'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x7FBAFA6DDFF0>
}
]
},
{
'role': 'assistant',
'content': [
{
'type': 'text',
'text': 'Challenge yourself with this 1000-piece MasterPieces Tribal Spirit jigsaw puzzle! Depicting the rich traditions and ceremonies of American Indian tribes, "The Chiefs" offers a stunning, culturally significant image perfect for puzzle enthusiasts. High-quality pieces guarantee a satisfying solve.'
}
]
}
]
4. 使用 trl 和 SFTTrainer微调VLM
我们现在已经准备好微调我们的模型。我们将使用 SFTTrainer 从 trl 来微调我们的模型。 SFTTrainer 使监督开放LLM和VLM的微调变得简单。 SFTTrainer 是 Trainer 的一个子类 transformers 库,并支持所有相同的功能,包括日志记录、评估和检查点,但增加了额外的生活质量功能。
在我们的示例中,我们将使用PEFT功能。作为PEFT方法,我们将使用QLoRA,这是一种在微调过程中减少大语言模型内存占用的技术,同时通过使用量化来不牺牲性能。如果你想了解更多关于QLoRA及其工作原理的信息,请查看 使用bitsandbytes、4比特量化和QLoRA让LLM更易获取 博客文章。
注意:由于我们需要填充我们的多模态输入,我们不能使用Flash Attention。
我们将使用Qwen 2 VL 7B模型,当然也可以轻松地将模型更换为其他模型,包括Meta AI的Llama-3.2-11B-Vision、Mistral AI的Pixtral-12B或任何其他LLMs,只需更改我们的model_id变量。我们将使用bitsandbytes将模型量化为4位。
注意:模型越大,所需的内存就越多。在我们的示例中,我们将使用7B版本,该版本可以在24GB的GPU上微调。
正确地为训练VLMs准备LLM、Tokenizer和Processor是至关重要的。Processor负责将特殊标记和图像特征包含在输入中。
import torch from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig # Hugging Face model id model_id = "Qwen/Qwen2-VL-7B-Instruct" # BitsAndBytesConfig int-4 config bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) # Load model and tokenizer model = AutoModelForVision2Seq.from_pretrained( model_id, device_map="auto", # attn_implementation="flash_attention_2", # not supported for training torch_dtype=torch.bfloat16, quantization_config=bnb_config ) processor = AutoProcessor.from_pretrained(model_id)
Loading checkpoint shards: 100%|██████████| 5/5 [02:31<00:00, 30.26s/it]
# Preparation for inference text = processor.apply_chat_template( dataset[2]["messages"], tokenize=False, add_generation_prompt=False ) print(text)
<|im_start|>system 你是一个专业的亚马逊产品描述作者。<|im_end|> <|im_start|>user 根据提供的##PRODUCT NAME##、##CATEGORY## 和图片创建简短的产品描述。仅返回描述内容。该描述应针对搜索引擎优化(SEO),并为移动搜索提供更好的体验。
## 产品名称 ##: Barbie Fashionistas Doll Wear Your Heart
## 类别 ##: Toys & Games | Dolls & Accessories | Dolls
<|vision_start|><|image_pad|><|vision_end|>
<|im_end|> <|im_start|>assistant Express your style with Barbie Fashionistas Doll Wear Your Heart! This fashionable doll boasts a unique outfit and accessories, perfect for imaginative play. A great gift for kids aged 3+. Collect them all! #Barbie #Fashionistas #Doll #Toys #GirlsToys #FashionDoll #Play<|im_end|>
该 SFTTrainer 支持与 peft的原生集成,这使得使用例如QLoRA等方法高效地调整LLMs变得超级简单。我们只需要创建我们的 LoraConfig 并将其提供给训练器。我们的LoraConfig参数是基于QLoRA论文和sebastian的博客文章定义的。
from peft import LoraConfig # LoRA config based on QLoRA paper & Sebastian Raschka experiment peft_config = LoraConfig( lora_alpha=16, lora_dropout=0.05, r=8, bias="none", target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM", )
在我们开始训练之前,我们需要定义要使用的超参数(SFTConfig),并确保输入正确地提供给模型。与仅文本的监督微调不同,我们需要将图像提供给模型。因此,我们创建了一个自定义的DataCollator,它正确地格式化输入并包含图像特征。我们使用process_vision_info方法,该方法由Qwen2团队提供的实用工具包提供。如果你使用的是其他模型,例如Llama 3.2 Vision,你可能需要检查该模型是否生成相同处理的图像信息。
from trl import SFTConfig from transformers import Qwen2VLProcessor from qwen_vl_utils import process_vision_info args = SFTConfig( output_dir="qwen2-7b-instruct-amazon-description", # 保存目录和仓库ID num_train_epochs=3, # 训练的轮数 per_device_train_batch_size=4, # 每个设备的训练批次大小 gradient_accumulation_steps=8, # 执行反向传播/更新前的累积步数 gradient_checkpointing=True, # 使用梯度检查点以节省内存 optim="adamw_torch_fused", # 使用融合的adamw优化器 logging_steps=5, # 每5步记录一次日志 save_strategy="epoch", # 每轮保存一次检查点 learning_rate=2e-4, # 学习率,基于QLoRA论文 bf16=True, # 使用bfloat16精度 tf32=True, # 使用tf32精度 max_grad_norm=0.3, # 最大梯度范数,基于QLoRA论文 warmup_ratio=0.03, # 预热比例,基于QLoRA论文 lr_scheduler_type="constant", # 使用恒定学习率调度器 # push_to_hub=False, # 将模型推送到hub report_to="tensorboard", # 向tensorboard报告指标 gradient_checkpointing_kwargs={"use_reentrant": False}, # 使用非重入式检查点 dataset_text_field="", # 为collator需要的占位字段 dataset_kwargs={"skip_prepare_dataset": True} # 对collator很重要 ) args.remove_unused_columns=False # Create a data collator to encode text and image pairs ''' collate_fn 是 PyTorch DataLoader 中的一个关键函数,用于将多个分散的样本(examples,通常是列表形式)整合成一个批次(batch)数据。 examples 是输入的样本列表,每个样本是一个字典,包含 messages(对话文本信息)和图像相关数据。 ''' def collate_fn(examples): # Get the texts and images, and apply the chat template ''' 1. 提取文本和图像输入 文本处理: processor.apply_chat_template(example["messages"], tokenize=False)
对每个样本的对话消息(example["messages"])应用模型要求的对话模板(如添加角色标识、分隔符等),返回格式化的文本字符串(tokenize=False 表示先不进行分词,仅做格式转换)。 最终 texts 是一个列表,包含所有样本的格式化文本。 图像处理: process_vision_info(example["messages"])
用于从对话消息中提取图像信息(如图像路径、像素数据等)并预处理(如 resize、归一化)。[0] 表示取处理后的第一个图像(通常一个对话对应一张图像)。 最终 image_inputs 是一个列表,包含所有样本的预处理后图像数据。 ''' texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] image_inputs = [process_vision_info(example["messages"])[0] for example in examples] # Tokenize the texts and process the images ''' 2. 文本分词与图像处理(构建批次) processor 是多模态处理器(如 Qwen2VLProcessor、LlavaProcessor 等),可同时处理文本和图像: text=texts:传入上一步处理的格式化文本列表,处理器会对其进行分词(转换为 input_ids、attention_mask 等)。 images=image_inputs:传入预处理后的图像列表,处理器会将图像转换为模型所需的张量格式(如像素值归一化、维度调整)。 return_tensors="pt":指定返回 PyTorch 张量(适配 PyTorch 模型)。 padding=True:对批次内的文本进行填充(用 pad_token 补全到最长序列长度),确保批次内所有样本的 input_ids 长度一致。 最终 batch 是一个字典,包含 input_ids(文本分词后的索引)、attention_mask(标记有效 token 位置)、pixel_values(图像张量)等关键数据。 ''' batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True) # The labels are the input_ids, and we mask the padding tokens in the loss computation ''' 3. 构建标签(用于损失计算) 在语言模型训练中,通常用输入序列(input_ids)作为标签(预测下一个 token),但需要忽略填充 token(pad_token)的损失(因为填充是人为添加的,无实际意义)。 labels = batch['input_ids'].clone():
复制 input_ids 作为初始标签。 labels[labels == processor.tokenizer.pad_token_id] = -100:
将标签中所有等于 pad_token_id(填充 token 的索引)的位置设为 -100。PyTorch 的交叉熵损失函数(CrossEntropyLoss)会自动忽略值为 -100 的标签,不计算其损失。 ''' labels = batch['input_ids'].clone() labels[labels == processor.tokenizer.pad_token_id] = -100 # Ignore the image token index in the loss computation (model specific) ''' 4. 忽略图像相关 token 的损失 多模态模型中,文本序列中会插入特殊的 “图像 token”(如 <image>),用于标识图像在文本中的位置(将图像特征与文本特征对齐)。
这些 token 是输入的一部分,不是模型需要预测的内容,因此需要在损失计算中忽略。 根据处理器类型(如 Qwen2VLProcessor 有特定的图像 token),定义图像相关 token 的索引(image_tokens)。 遍历这些图像 token,将标签中对应位置的值设为 -100,确保损失计算时忽略它们。 ''' if isinstance(processor, Qwen2VLProcessor): image_tokens = [151652,151653,151655] # 在 Qwen2-VL 模型(阿里云推出的多模态模型)中,image_tokens = [151652, 151653, 151655] 对应的是图像相关的特殊 token 的 ID,
# 用于在文本序列中标记图像的位置、类型或边界,实现文本与图像的对齐(多模态融合)。 else: image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)] ''' 不同多模态模型的 image_token 可能不同(由模型设计决定): 例如 LLaVA 模型的 image_token 通常是 <image>; 而 Qwen2-VL 模型不使用单一的 image_token,而是用一组特殊 token(如之前提到的 [151652, 151653, 151655])
来标记图像相关位置,因此它的处理器(Qwen2VLProcessor)可能没有 image_token 属性,需要单独指定这组 ID。 ''' for image_token_id in image_tokens: labels[labels == image_token_id] = -100 batch["labels"] = labels return batch
我们现在已经有创建我们所需的所有基本要素 SFTTrainer 以开始训练我们的模型。
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=args,
train_dataset=dataset,
data_collator=collate_fn,
dataset_text_field="", # needs dummy value
peft_config=peft_config,
tokenizer=processor.tokenizer,
)
通过调用 train() 方法来开始训练我们的 Trainer 实例。这将启动训练循环,并训练我们的模型3个完整的周期。由于我们使用的是PEFT方法,我们只会保存调整后的模型权重,而不是整个模型。
# start training, the model will be automatically saved to the hub and the output directory trainer.train() # save model trainer.save_model(args.output_dir)
[126/126 46:54, Epoch 2/3]
Step Training Loss 5 2.818500 10 2.499200 15 2.114200 20 1.682500 25 1.315200 30 1.119800 35 1.022300 40 0.987600 45 0.937300 50 0.910600 55 0.899600 60 0.854100 65 0.831400 70 0.801400 75 0.789800 80 0.794500 85 0.779300 90 0.749600 95 0.754200 100 0.737100 105 0.733300 110 0.736100 115 0.732700 120 0.743100 125 0.745500
用大约1千个样本的训练数据训练3个epoch花费了01:31:58在一个aws云主机g6.2xlarge上。实例费用为0.9776$/h,这使我们总费用仅为1.4$。
# free the memory again del model del trainer torch.cuda.empty_cache()
4. 测试模型并运行推理
训练完成后,我们希望评估和测试我们的模型。首先,我们会加载基础模型,并让它生成一个随机的亚马逊产品的描述。然后,我们会加载我们调整后的Q-LoRA模型,并让它生成相同产品的描述。
最后,我们可以将适配器合并到基础模型中,使其更高效,并再次对该产品进行推理。
import torch from transformers import AutoProcessor, AutoModelForVision2Seq adapter_path = "./qwen2-7b-instruct-amazon-description" # Load Model base model model = AutoModelForVision2Seq.from_pretrained( model_id, device_map="auto", torch_dtype=torch.float16 ) processor = AutoProcessor.from_pretrained(model_id)
Loading checkpoint shards: 100%|██████████| 5/5 [00:07<00:00, 1.45s/it]
我从亚马逊上随机选择了一款产品,并准备了一个generate_description函数来生成该产品的描述。
from qwen_vl_utils import process_vision_info # sample from amazon.com sample = { "product_name": "Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held, Iron Man, 30,5 cm Actionfigur", "catergory": "Toys & Games | Toy Figures & Playsets | Action Figures", "image": "https://m.media-amazon.com/images/I/81+7Up7IWyL._AC_SY300_SX300_.jpg" } # prepare message messages = [{ "role": "user", "content": [ { "type": "image", "image": sample["image"], }, {"type": "text", "text": prompt.format(product_name=sample["product_name"], category=sample["catergory"])}, ], } ] def generate_description(sample, model, processor): messages = [ {"role": "system", "content": [{"type": "text", "text": system_message}]}, {"role": "user", "content":[ {"type": "image", "image": sample["image"]}, {"type": "text", "text": prompt.format(product_name=sample["product_name"], category=sample["catergory"])}, ]}, ] # Preparation for inference text = processor.apply_chat_template( messages, tokenize = False, add_generation_prompt=True ) image_inputs, video_inputs = process_vision_info(messages) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) inputs = inputs.to(model.device) # Inference: Generation of the output generated_ids = model.generate(**inputs, max_new_tokens=256, top_p=1.0, do_sample=True, temperature=0.8) generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False ) return output_text[0]
太棒了,它工作正常!让我们加载我们的适配器并与基础模型进行比较。
from PIL import Image import requests from io import BytesIO # 发送请求获取图片数据 response = requests.get(sample['image']) response.raise_for_status() # 检查请求是否成功 # 将二进制数据转换为PIL可处理的对象 img = Image.open(BytesIO(response.content)) # 显示图片 img.show() # let's generate the description base_description = generate_description(sample, model, processor) print(base_description) # you can disable the active adapter if you want to rerun it with # model.disable_adapters()

这款Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held Iron Man Actionfigur是一款高度详细的30.5厘米铁人英雄动作人偶,是任何超级英雄迷的必备收藏品。这款人偶具有逼真的细节和生动的色彩,是孩子们的完美玩具,也是成年人的收藏佳品。它不仅适合儿童玩耍,也适合作为礼物或装饰品。这款人偶是Marvel系列的一部分,是粉丝们喜爱的英雄之一,非常适合用来展示或收藏。
model.load_adapter(adapter_path) # load the adapter and activate ft_description = generate_description(sample, model, processor) print(ft_description)
Unleash the power of Iron Man with this Hasbro Marvel Avengers Titan Hero Series 30.5 cm action figure!
This highly detailed Iron Man figure is perfect for collectors and kids alike. Features Titan Hero port for compatible accessories (sold separately). A must-have for any Marvel fan!
让我们将它们并排比较,使用 markdown 表格。
import pandas as pd from IPython.display import display, HTML def compare_generations(base_gen, ft_gen): # Create a DataFrame df = pd.DataFrame( { 'Base Generation': [base_gen], 'Fine-Truned Gen': [ft_gen] } ) # Style the DataFrame styled_df = df.style.set_properties( **{ 'text-align': 'left', 'white-space': 'pre-wrap', 'border': '1px solid black', 'padding': '10px', 'width': '250px', # Set width to 150px 'overflow-wrap': 'break-word' # Allow words to break and wrap as needed } ) # Display the styled DataFrame display(HTML(styled_df.to_html())) compare_generations(base_description, ft_description)
| Base Generation | Fine-Truned Gen | |
|---|---|---|
| 0 | 这款Hasbro Marvel Avengers-Serie Marvel Assemble Titan-Held Iron Man Actionfigur是一款高度详细的30.5厘米铁人英雄动作人偶,是任何超级英雄迷的必备收藏品。这款人偶具有逼真的细节和生动的色彩,是孩子们的完美玩具,也是成年人的收藏佳品。它不仅适合儿童玩耍,也适合作为礼物或装饰品。这款人偶是Marvel系列的一部分,是粉丝们喜爱的英雄之一,非常适合用来展示或收藏。 |
Unleash the power of Iron Man with this Hasbro Marvel Avengers Titan Hero Series 30.5 cm action figure! This highly detailed Iron Man figure is perfect for collectors and kids alike. Features Titan Hero port for compatible accessories (sold separately). A must-have for any Marvel fan! |
不错!尽管我们刚刚有 ~1k 个样本,我们仍然可以看到微调提高了产品描述生成。描述更短更简洁,这符合我们的训练数据。
可选:将LoRA适配器合并到原始模型中
在使用QLoRA时,我们只训练适配器而不是整个模型。这意味着在训练过程中保存模型时,我们只保存适配器权重而不是整个模型。如果你希望保存整个模型,以便于与文本生成推理一起使用,可以使用merge_and_unload方法将适配器权重合并到模型权重中,然后使用save_pretrained方法保存模型。这将保存一个默认模型,可以用于推理。
注意:这需要超过30GB的CPU内存。
from peft import PeftModel from transformers import AutoProcessor, AutoModelForVision2Seq adapter_path = "./qwen2-7b-instruct-amazon-description" base_model_id = "Qwen/Qwen2-VL-7B-Instruct" merged_path = "merged" # Load Model base model model = AutoModelForVision2Seq.from_pretrained(model_id, low_cpu_mem_usage=True) # Path to save the merged model # Merge LoRA and base model and save peft_model = PeftModel.from_pretrained(model, adapter_path) merged_model = peft_model.merge_and_unload() merged_model.save_pretrained(merged_path, safe_serialization=True, max_shard_size="2GB") processor = AutoProcessor.from_pretrained(base_model_id) processor.save_pretrained(merged_path)
Loading checkpoint shards: 100%|██████████| 5/5 [01:34<00:00, 18.81s/it]
奖励:使用TRL示例脚本
TRL提供了一个简单的示例脚本来微调多模态模型。你可以在这里找到这个脚本。该脚本可以直接从命令行运行,并支持SFTTrainer的所有功能。
# Tested on 8x H100 GPUs !accelerate launch \ --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ examples/scripts/sft_vlm.py \ --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ --model_name_or_path llava-hf/llava-1.5-7b-hf \ --per_device_train_batch_size 8 \ --gradient_accumulation_steps 8 \ --output_dir sft-llava-1.5-7b-hf \ --bf16 \ --torch_dtype bfloat16 \ --gradient_checkpointing

浙公网安备 33010602011771号