使用trl-qlora微调qwen2.5
大型语言模型(LLMs)在过去一年中取得了许多进展。我们从现在ChatGPT的竞争对手发展到一个包含Meta AI的Llama 3,Mistral的Mistral和Mixtral模型,TII的Falcon,以及许多其他模型。 这些LLMs可以用于各种任务,包括聊天机器人、问答、无需额外训练的摘要。然而,如果你希望为应用程序定制模型,你可能需要在自有数据上微调模型,以获得比提示更好的结果,或者通过训练更小的模型来节省成本和提高效率。
这篇文章将指导你如何使用Hugging Face进行开放LLM的微调TRL和Transformers。你将学习如何:
- 什么是Q-LoRA以及微调模型的系统要求
- 设置开发环境
- 创建和准备数据集
- 使用
trl
微调LLMSFTTrainer
- 测试和评估LLM
注意:此博客是为了在使用NVIDIA T4 GPU和16GB内存的免费Google colaboratory账户上运行而创建的,但可以轻松改编以在更大的GPU和更大的模型上运行,请参见下面的内存要求。
1. Q-LoRA是什么,以及微调模型的系统要求
Quantized Low-Rank Adaptation (QLoRA) 已经成为一种流行的方法,用于高效地微调 LLMs,因为它在大幅减少计算资源需求的同时保持高性能。在 Q-LoRA 中,预训练模型被量化到 4 位,并且权重被冻结。然后附加可训练的适配器层 (LoRA),并且只训练适配器层。之后,适配器权重可以与基础模型合并,或者保持为单独的适配器。
QLoRA的内存效率使得微调在各种硬件配置上成为可能。
模型大小 | 最低显存要求 |
---|---|
1B | 5 GB |
4B | 12 GB |
12B | 24 GB |
27B | 40 GB |
2. 设置开发环境
第一步是安装Hugging Face库,包括trl和transformers。如果你还没有听说过trl,不要担心。trl是一个基于transformers和datasets的库,它使微调、rlhf、对齐开放LLM更加容易。
# Install Pytorch & other libraries !pip install "torch>=2.4.0" tensorboard # Install Hugging Face libraries !pip install --upgrade \ "transformers==4.49.0" \ "datasets==3.3.2" \ "accelerate==1.4.0" \ "evaluate==0.4.3" \ "bitsandbytes==0.45.3" \ "trl==0.15.2" \ "peft==0.14.0" # COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, e.g. L4 !pip install flash-attn
注意:如果你使用的是具有Ampere架构的GPU(例如NVIDIA L4)或更新的型号,你可以使用Flash attention。Flash Attention是一种重新排序注意力计算并利用经典技术(平铺、重新计算)的方法,可以显著加快计算速度并减少内存使用量,从序列长度的平方减少到线性。简而言之;加速训练最多3倍。了解更多:FlashAttention.
3. 创建和准备数据集
在微调 LLMs 时,重要的是你知道你的使用案例和你想要解决的任务。这将帮助你创建一个用于微调模型的数据集。如果你还没有定义你的使用案例,你可能需要从头开始。
例如,这个博客关注以下用例:
微调一个模型,该模型可以根据自然语言指令生成SQL查询,然后可以将其集成到数据分析工具中。目标是减少创建SQL查询所需的时间,并使非技术用户更轻松地创建SQL查询。
文本到SQL可以成为微调LLMs的一个很好的用例,因为这是一个需要对数据和SQL语言有大量(内部)知识的复杂任务。
一旦你确定微调是正确的解决方案,我们需要创建一个数据集来微调我们的模型。该数据集应为所需解决的任务提供多样化的示例。创建此类数据集的方法有多种,包括:
每种方法都有其自身的优点和缺点,并且取决于预算、时间以及质量要求。例如,使用现有的数据集是最简单的,但可能无法针对你的特定用例进行定制,而使用人工可能最准确,但可能耗时且昂贵。也可以将几种方法结合起来创建一个指令数据集,如 Orca: Progressive Learning from Complex Explanation Traces of GPT-4。
这个例子使用了一个已经存在的数据集 (gretelai/synthetic_text_to_sql),一个高质量的合成文本到SQL数据集,包括自然语言指令、模式定义、推理和相应的SQL查询。
Hugging Face TRL 支持对话数据集格式的自动模板化。这意味着我们只需将数据集转换为正确的json对象,trl
将负责模板化并将其转换为正确的格式。
{"messages":[{"role":"system","content":"你是..."},{"role":"user","content":"..."},{"role":"assistant","content":"..."}]} {"messages":[{"role":"system","content":"You are..."{"role":"user","content":"..."{"role":"assistant","content":"..."}]}
该gretelai/synthetic_text_to_sql包含超过10万条样本,为了保持示例简单,它将数据集减少到只有10,000条样本。
{
'id': 66411,
'domain': 'fitness industry',
'domain_description': 'Workout data, membership demographics, wearable technology metrics, and wellness trends.',
'sql_complexity': 'single join',
'sql_complexity_description': 'only one join (specify inner, outer, cross)',
'sql_task_type': 'analytics and reporting', 'sql_task_type_description': 'generating reports, dashboards, and analytical insights', 'sql_prompt': 'How many members attended each type of class in April 2022?',
'sql_context': "CREATE TABLE Members (MemberID INT, Age INT, Gender VARCHAR(10), MembershipType VARCHAR(20)); INSERT INTO Members (MemberID, Age, Gender, MembershipType) VALUES (1, 35, 'Female', 'Premium'), (2, 45, 'Male', 'Basic'), (3, 28, 'Female', 'Premium'), (4, 32, 'Male', 'Premium'), (5, 48, 'Female', 'Basic'); CREATE TABLE ClassAttendance (MemberID INT, Class VARCHAR(20), Date DATE); INSERT INTO ClassAttendance (MemberID, Class, Date) VALUES (1, 'Cycling', '2022-04-01'), (2, 'Yoga', '2022-04-02'), (3, 'Cycling', '2022-04-03'), (4, 'Yoga', '2022-04-04'), (5, 'Pilates', '2022-04-05'), (1, 'Cycling', '2022-04-06'), (2, 'Yoga', '2022-04-07'), (3, 'Cycling', '2022-04-08'), (4, 'Yoga', '2022-04-09'), (5, 'Pilates', '2022-04-10');", 'sql': 'SELECT Class, COUNT(*) as AttendanceCount FROM Members JOIN ClassAttendance ON Members.MemberID = ClassAttendance.MemberID WHERE MONTH(ClassAttendance.Date) = 4 GROUP BY Class;', 'sql_explanation': 'We are joining the Members table with the ClassAttendance table based on the MemberID. We then filter the records to only those where the month of the Date is April and group the records by the Class and calculate the attendance count using the COUNT function.' }
现在你可以使用🤗 Datasets库来加载数据集,并创建一个提示模板,将自然语言指令、模式定义结合起来,并为我们的助手添加一个系统消息。
注意:此步骤可能根据你的使用案例而异。例如,如果你已经从 OpenAI 等渠道获得了一个数据集,你可以跳过此步骤并直接进行微调步骤。
from datasets import load_dataset # System message for the assistant system_message = """你是一个文本 SQL 查询翻译器,用户会用英文问你问题,你会根据提供的 ScheMA 生成一个 SQL 查询""" # User prompt that combines the user query and the schema user_prompt = """给定 <USER_QUERY> 和<ScheMA>,生成相应的 SQL 命令以检索所需数据,同时考虑查询的语法、语义学和模式约束。 <SCHEMA> {context} </SCHEMA> <USER_QUERY> {question} </USER_QUERY> """ def create_conversation(sample): return { "messages": [ {"role": "system", "content": system_message}, {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])}, {"role": "assistant", "content": sample["sql"]} ] } # Load dataset from the hub dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train") dataset = dataset.shuffle().select(range(12500)) ''' load_dataset函数主要用于从本地或 Hugging Face 模型库中获取所需的数据集。如果数据集仅包含数据文件,它会根据文件扩展名(如 json、csv、parquet、txt 等)自动推断如何加载数据文件。若数据集有数据集脚本,则会从模型库中下载并导入该脚本,根据脚本中的代码来生成和显示数据集示例。 - 功能: load_dataset函数主要用于从本地或 Hugging Face 模型库中获取所需的数据集。如果数据集仅包含数据文件,它会根据文件扩展名(如 json、csv、parquet、txt 等)自动推断如何加载数据文件。若数据集有数据集脚本,则会从模型库中下载并导入该脚本,根据脚本中的代码来生成和显示数据集示例。 - 用法示例: - 加载 Hugging Face 模型库中的数据集: 可以直接传入数据集在模型库中的名称来加载,如from datasets import load_dataset; dataset = load_dataset("imdb"),这将加载 IMDB 影评数据集。 - 加载本地数据集: 若数据集为 CSV 格式,可使用from datasets import load_dataset; load_dataset('csv', data_files='path/to/your/dataset.csv')来加载,其中data_files参数指定了本地数据集文件的路径。 - 加载特定版本的数据集: 部分数据集可能有多个版本,可通过revision参数指定版本,如dataset = load_dataset("namespace/your_dataset_name", revision="specific_version")。 - 指定数据划分: 可以通过data_files参数将数据文件映射到训练集、验证集和测试集等划分,如dataset = load_dataset("namespace/your_dataset_name", data_files=data_files),其中data_files是一个包含各划分数据文件路径的字典。 ''' # Convert dataset to OAI messages dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False) # split dataset into 10,000 training samples and 2,500 test samples dataset = dataset.train_test_split(test_size=2500/12500) # Print formatted user prompt print(dataset["train"][345]["messages"][1]["content"])
Map: 100%|██████████| 12500/12500 [00:01<00:00, 6483.04 examples/s] 给定 <USER_QUERY> 和<ScheMA>,生成相应的 SQL 命令以检索所需数据,同时考虑查询的语法、语义学和模式约束。 <SCHEMA> CREATE TABLE Projects (project_id INT, contractor_id INT, start_date DATE, end_date DATE); </SCHEMA> <USER_QUERY> List all projects with a start date on or after "2022-01-01" from the "Projects" table. </USER_QUERY>
4. 使用 trl
精调 LLM SFTTrainer
现在你可以微调你的模型了。Hugging Face TRL SFTTrainer 使监督微调开放的LLM变得简单。 SFTTrainer
是 Trainer
的一个子类 transformers
库,并支持所有相同的功能,包括日志记录、评估和检查点,但增加了额外的生活质量功能,包括:
- 数据集格式化,包括对话和指令格式
- 仅对完成进行训练,忽略提示
- 打包数据集以实现更高效的训练
- PEFT(参数高效微调)支持包括Q-LoRA
- 为对话微调准备模型和分词器(例如,添加特殊标记)
这个例子使用了Gemma,你可以通过更改model_id
变量来换用其他版本。我们将使用bitsandbytes将我们的模型量化到4位。
注意:模型越大,所需的内存就越多。在我们的示例中,我们将使用8B版本,该版本可以在24GB的GPU上微调。如果你有一个较小的GPU。
import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig # Hugging Face model id model_id = "Qwen/Qwen2.5-3B-Instruct" # Check if GPU benefits from bfloat16 if torch.cuda.get_device_capability()[0] >= 8: torch_dtype = torch.bfloat16 else: torch_dtype = torch.float16 # define model kwargs model_kwargs = dict( attn_implementation="sdpa", # 注意力实现机制,可以使用flash_attention_2代替 torch_dtype=torch_dtype, # 使用的torch dtype类型,默认为自动 use_cache=False, # 我们使用梯度检查点 device_map="auto", # define model kwargs ) model_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=model_kwargs['torch_dtype'], bnb_4bit_quant_storage=model_kwargs['torch_dtype'], ) ''' BitsAndBytesConfig 是 Hugging Face transformers 库中用于配置模型量化参数的类,主要与 bitsandbytes 库配合使用,实现模型的低精度量化(如 4 位、8 位量化)。其核心作用是通过量化模型权重和激活值,在几乎不损失性能的前提下,大幅降低模型的内存占用,使得大型模型(如 10B、30B 参数)可以在消费级 GPU 上运行或训练。 核心功能 支持 4 位(4-bit)和 8 位(8-bit)量化,平衡模型性能与内存占用。 提供多种量化策略(如双量化、NF4 量化),优化量化精度。 可直接与 AutoModelForCausalLM 等模型加载方法结合,简化量化流程。 主要参数说明 BitsAndBytesConfig 的常用参数如下(详细参数可参考官方文档): 参数 类型 说明 load_in_4bit bool 是否启用 4 位量化(与load_in_8bit互斥) load_in_8bit bool 是否启用 8 位量化(与load_in_4bit互斥) bnb_4bit_use_double_quant bool 4 位量化时是否使用双量化(减少量化误差,推荐启用) bnb_4bit_quant_type str 4 位量化类型,可选"fp4"(浮点 4 位)或"nf4"(归一化浮点 4 位,推荐,对小数据集更稳定) bnb_4bit_compute_dtype torch.dtype 计算时使用的数据类型(如torch.float16或torch.bfloat16,影响计算精度和速度) 注意事项 1.** 设备兼容性 :4 位量化需要 GPU 支持 CUDA(不支持 CPU),且推荐 GPU 计算能力≥7.0(如 RTX 2000 系列及以上)。 2. 模型限制 :并非所有模型都支持量化,主要适用于AutoModelForCausalLM、AutoModelForSeq2SeqLM等生成式模型。 3. 性能权衡 :4 位量化内存占用约为 FP16 的 1/4,8 位约为 1/2,但可能损失少量精度(具体因模型和任务而异)。 4. 缓存清理 **:若修改量化配置后加载模型异常,可删除~/.cache/huggingface/hub下的缓存文件重试。 通过 BitsAndBytesConfig,可以轻松实现大模型的低精度部署和训练,特别适合资源有限的场景(如个人 GPU、小型服务器)。 ''' # Load model and tokenizer model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) ''' AutoModelForCausalLM.from_pretrained 是 Hugging Face transformers 库中用于加载预训练因果语言模型(如 GPT、LLaMA、Mistral 等)的核心函数。它属于 “AutoModel” 系列,能自动根据模型名称或路径推断对应的模型架构,无需手动指定具体模型类(如 GPT2LMHeadModel、LlamaForCausalLM 等),极大简化了模型加载流程。 核心作用 从 Hugging Face Hub(模型库)或本地路径加载预训练的因果语言模型(支持文本生成任务)。 自动处理模型权重的加载、架构匹配、设备分配(CPU/GPU)等细节。 支持量化、并行计算、自定义配置等高级功能,适配不同硬件和场景(推理、微调等)。 关键参数详解 from_pretrained 函数有众多参数,以下是最常用的核心参数: 参数 类型 说明 pretrained_model_name_or_path str 必填,模型名称(如 "HuggingFaceH4/zephyr-7b-beta")或本地模型文件夹路径。 config PretrainedConfig 模型配置对象,若未指定,会自动从模型路径加载。 from_tf bool 是否从 TensorFlow 格式的权重加载(默认False)。 from_flax bool 是否从 Flax 格式的权重加载(默认False)。 cache_dir str 模型缓存路径(默认~/.cache/huggingface/hub),可指定本地文件夹保存下载的模型。 device_map str/dict 设备分配策略: - "auto":自动分配到可用设备(优先 GPU,剩余 CPU); - "cpu":强制加载到 CPU; - "cuda"/0:加载到指定 GPU(如"cuda:0"); - 字典:手动指定各层到设备的映射(适合大模型分片)。 quantization_config QuantizationConfig 量化配置(如BitsAndBytesConfig),用于低精度加载模型(4 位 / 8 位量化),降低内存占用。 load_in_8bit/load_in_4bit bool 快速启用 8 位 / 4 位量化(等价于配置quantization_config的简化方式),默认False。 trust_remote_code bool 是否信任模型中包含的远程代码(如自定义架构),加载非标准模型时需设为True(默认False)。 revision str 模型版本(如分支名、commit 哈希),用于加载特定版本的模型(默认"main")。 ''' tokenizer = AutoTokenizer.from_pretrained(model_id) ''' AutoTokenizer.from_pretrained 是 Hugging Face transformers 库中用于加载预训练分词器的核心函数。分词器(Tokenizer)的作用是将原始文本转换为模型可识别的输入格式(如 token ID、注意力掩码等),是连接自然语言与模型输入的 “桥梁”。 AutoTokenizer 系列的优势在于:无需手动指定具体分词器类(如 GPT2Tokenizer、LlamaTokenizer),能自动根据模型名称或路径推断对应的分词器类型,极大简化了预处理流程。 核心作用 加载与预训练模型配套的分词器(确保文本处理规则与模型训练时一致)。 提供文本到 token 的转换(分词)、token 到 ID 的映射、添加特殊符号(如 [CLS]、[SEP])、生成注意力掩码(attention mask)等功能。 支持批量处理、截断(truncation)、填充(padding)等预处理操作,适配不同长度的文本输入。 关键参数详解 from_pretrained 函数的常用参数如下: 参数 类型 说明 pretrained_model_name_or_path str 必填,模型名称(如 "meta-llama/Llama-2-7b-chat-hf")或本地分词器文件夹路径(包含tokenizer_config.json等文件)。 cache_dir str 分词器缓存路径(默认~/.cache/huggingface/hub),可指定本地文件夹保存下载的分词器文件。 use_fast bool 是否使用快速分词器(基于 Rust 实现),速度更快,功能更全(默认True)。若模型不支持快速分词器,会自动降级为慢速(Python 实现)。 padding_side str 填充(padding)的方向,可选"left"或"right"(默认"right")。生成式模型(如 GPT 系列)通常需设为"left",避免填充符号影响生成结果。 truncation_side str 截断(truncation)的方向,可选"left"或"right"(默认"right"),长文本超出最大长度时从哪边截断。 model_max_length int 分词器支持的最大文本长度(token 数),若未指定,会从模型配置中读取(如 GPT-2 默认 512,LLaMA 默认 2048)。 trust_remote_code bool 是否信任远程代码(如自定义分词器逻辑),加载非标准分词器时需设为True(默认False)。 注意事项 与模型匹配:分词器必须与模型配套使用(如Llama-2-7b模型需用其专用分词器),否则会因词汇表不匹配导致模型输入错误。 特殊符号处理:部分模型(如 LLaMA)默认没有pad_token,需手动指定(通常设为eos_token),否则填充时会报错。 最大长度限制:若文本长度超过model_max_length且未启用truncation,会抛出错误,需根据任务合理设置max_length。 快速分词器兼容性:use_fast=True是推荐设置,但部分旧模型可能仅支持慢速分词器,此时需设为False。 缓存机制:首次加载会下载分词器文件(体积较小,通常几 MB),后续加载从缓存读取,可通过cache_dir指定缓存路径。 ''' if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.37s/it]
该 SFTTrainer
支持与 peft
的原生集成,这使得使用QLoRA高效地调整LLMs变得简单。你只需要创建一个LoraConfig
并将其提供给训练器。
from peft import LoraConfig peft_config = LoraConfig( lora_alpha=16, # LoRA 的缩放参数,通常与r(秩)成比例设置 lora_dropout=0.05, # 应用于 LoRA 层的 dropout 概率,防止过拟合 r=16, # LoRA 的秩,决定了低秩矩阵的维度,数值越小参数效率越高 bias="none", # 不训练偏置参数 target_modules="all-linear", # 指定对模型中所有线性层应用 LoRA task_type="CAUSAL_LM", # 任务类型为因果语言模型(适用于文本生成任务) # lora_modules_to_save=["lm_head", "embed_tokens"] # 指定除了 LoRA 参数外还需要保存的模块 )
在你开始训练之前,你需要定义要使用的超参数 (SFTConfig
)。
from trl import SFTConfig args = SFTConfig( output_dir = "text_to_sql", # 模型保存目录及仓库ID max_seq_length=1024, # 模型输入的最大序列长度,同时用于数据集样本的打包处理 packing=True, # 将数据集中的多个样本组合成单个序列,以提高训练效率 num_train_epochs=1, # 训练的总轮数 per_device_train_batch_size=1, # 每个设备上的训练批次大小 gradient_accumulation_steps=4, # 梯度累积步数(每累积4步进行一次反向传播和参数更新) gradient_checkpointing=True, # 使用梯度检查点技术节省内存(牺牲少量速度换内存) optim="adamw_torch_fused", # 使用融合版AdamW优化器(训练速度更快) logging_steps=10, # 每10步记录一次训练日志 save_strategy="epoch", # 按轮次保存模型检查点(每个epoch结束后保存) learning_rate=2e-4, # 学习率,基于QLoRA论文的推荐设置 fp16=True if torch_dtype == torch.float16 else False, # 若使用float16精度则启用fp16 bf16=False if torch_dtype == torch.float16 else True, # 若不使用float16则启用bf16精度(平衡速度与精度) max_grad_norm=0.3, # 梯度裁剪的最大范数(防止梯度爆炸,基于QLoRA论文) warmup_ratio=0.03, # 学习率预热比例(前3%的步骤逐渐提升到设定学习率,基于QLoRA论文) lr_scheduler_type="constant", # 学习率调度策略(使用恒定学习率) report_to="tensorboard", # 训练指标的报告工具(使用tensorboard) dataset_kwargs={ "add_special_tokens": False, # 不在数据集中自动添加特殊token(因为模板中已包含) "append_concat_token": False, # 不需要在拼接样本时添加额外的分隔token } )
现在你已经拥有了创建你的SFTTrainer
所需的所有基本要素,然后开始训练你的模型。
from trl import SFTTrainer trainer = SFTTrainer( model=model, args=args, train_dataset=dataset["train"], peft_config=peft_config, processing_class=tokenizer )
Converting train dataset to ChatML: 100%|██████████| 10000/10000 [00:01<00:00, 8384.54 examples/s] Applying chat template to train dataset: 100%|██████████| 10000/10000 [00:01<00:00, 5109.31 examples/s] Tokenizing train dataset: 100%|██████████| 10000/10000 [00:07<00:00, 1253.64 examples/s] Packing train dataset: 100%|██████████| 10000/10000 [00:06<00:00, 1585.62 examples/s]
通过调用 train()
方法开始训练。
# start training, the model will be automatically saved to the hub and the output directory trainer.train() # save model the final model again to the hub trainer.save_model()
[575/575 33:27, Epoch 0/1] Step Training Loss 10 0.834300 20 0.546000 30 0.436700 40 0.404900 50 0.408100 60 0.401600 70 0.394400 80 0.392000 90 0.394000 ... 570 0.353300
在使用QLoRA时,你只训练适配器而不是整个模型。这意味着在训练过程中保存模型时,你只保存适配器的权重,而不是整个模型。如果你想保存整个模型,以便更容易地与vLLM或TGI等推理堆栈一起使用,你可以使用merge_and_unload
方法将适配器的权重合并到模型的权重中,然后使用save_pretrained
方法保存模型。这将保存一个默认模型,可以用于推理。
注意:这需要超过30GB的CPU内存。
from peft import AutoPeftModelForCausalLM # Load PEFT model on CPU model = AutoPeftModelForCausalLM.from_pretrained( args.output_dir, torch_dtype=torch.float16, low_cpu_mem_usage=True, ) # Merge LoRA and base model and save merged_model = model.merge_and_unload() merged_model.save_pretrained(args.output_dir, safe_serialization=True, max_shard_size="2GB")
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.27it/s]
4. 测试模型并运行推理
训练完成后,你需要评估和测试你的模型。你可以加载测试数据集中的不同样本,并在这些样本上评估模型。
注意:评估生成式人工智能模型不是一件容易的事,因为一个输入可以有多个正确的输出。此示例仅关注人工评估和氛围检查。
import torch from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM from transformers import Qwen2ForCausalLM, AutoConfig model_id = "./text_to_sql" # Load Model with PEFT adapter model = AutoModelForCausalLM.from_pretrained( model_id, device_map="auto", torch_dtype=torch.float16 ) tokenizer = AutoTokenizer.from_pretrained(model_id) # load into pipeline pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) ''' 在 Hugging Face Transformers 库中,pipeline 是一个高层接口,它封装了模型加载、输入预处理、模型推理和输出后处理的完整流程,让开发者可以用极简的代码实现各种 NLP 任务(如文本生成、情感分析、翻译等)。 针对你提到的「文本生成(text-generation)」任务,pipeline 的用法详解如下 1. 初始化文本生成pipeline pipe = pipeline("text-generation", model=model, tokenizer=tokenizer) 2. 输入提示词 prompt = "Artificial intelligence is" 3. 生成文本 result = pipe(prompt) 4. 输出结果(默认返回一个列表,包含生成的文本) print(result[0]["generated_text"]) # 示例输出:"Artificial intelligence is transforming the way we live and work, with applications ranging from healthcare to finance..." '''
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.43s/it] Device set to use cuda:0
让我们从测试中加载一个随机样本并生成一个SQL命令。
from random import randint import re # Load our test dataset rand_idx = randint(0, len(dataset["test"])) # Test on sample prompt = pipe.tokenizer.apply_chat_template(dataset["test"][rand_idx]["messages"][:2], tokenize=False, add_generation_prompt=True) outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id) # Extract the user query and original answer pattern = r'<USER_QUERY>\n(.*?)\n*</USER_QUERY>' ''' print(dataset['test'][rand_idx]['messages'][1]['content']) 给定 <USER_QUERY> 和<ScheMA>,生成相应的 SQL 命令以检索所需数据,同时考虑查询的语法、语义学和模式约束。 <SCHEMA> CREATE TABLE projects (id INT, name TEXT, continent TEXT, start_date DATE); INSERT INTO projects (id, name, continent, start_date) VALUES ('Asia Coal', 'Asia', '2018-01-01'), ('Asia Iron Ore', 'Asia', '2015-05-12'); </SCHEMA> <USER_QUERY> What is the minimum start date for mining projects in Asia? </USER_QUERY> ''' print(f"Query:\n{re.search(pattern, dataset['test'][rand_idx]['messages'][1]['content'], re.DOTALL).group(1)}") print(f"Original Answer:\n{dataset['test'][rand_idx]['messages'][2]['content']}") print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
Query: What is the number of vaccines administered per state? Original Answer: SELECT s.state_name, v.vaccine_count FROM states s JOIN vaccinations v ON s.state_id = v.state_id; Generated Answer: SELECT s.state_name, SUM(v.vaccine_count) as total_vaccines FROM states s INNER JOIN vaccinations v ON s.state_id = v.state_id GROUP BY s.state_name;
太棒了!你的模型能够根据自然语言指令生成SQL查询。
注意:如上所述,评估生成模型并不是一件简单的事情。在这个例子中,你可以使用生成的SQL与真实SQL查询的准确率作为你的指标。另一种方法是自动执行生成的SQL查询,并将结果与真实值进行比较。这将是一个更准确的指标,但需要更多的工作来设置。