weizhang2024

  博客园  :: 首页  :: 新随笔  :: 联系 :: 订阅 订阅  :: 管理

模型的基本搭建

基本工具的部署与实践

为了后续模型部分的代码更加简易的书写,我们小组先搜寻并写了一些通用工具来方便后续不同策略链的代码的实现,整个utils的目录整体如下,我将大致介绍每部分代码的大致功能,具体的使用过程在后续的具体使用时再详细阐述:

config.py:包含模型和训练的配置信息。

  • DatasetAttr 类

    该类定义了数据集的属性。用于存储数据集属性,包含数据加载来源,名称,SHA1校验码等

    • load_from:从哪里加载数据。
    • dataset_name:数据集的名称。
    • dataset_sha1:数据集的哈希值。
    • source_prefix:可能用于指定数据的来源前缀。
    • prompt_column, query_column, response_column, history_column:指定数据集中的各列名称。
  • ModelArguments 类

    此类定义了与模型相关的参数。

    配置模型路径和缓存目录:指定预训练模型的路径或标识符以及缓存目录。
    配置分词器:选择是否使用快速分词器。
    配置身份验证:选择是否使用身份验证令牌。
    配置模型版本:指定使用的模型版本。
    配置量化选项:选择量化位数、类型以及是否使用双重量化。
    配置检查点目录:指定保存检查点的目录。
    配置奖励模型路径:指定奖励模型的路径。
    配置微调选项:选择是否从上次的LoRA权重恢复训练。
    配置训练损失绘图:选择是否在微调后绘制训练损失图。
    
    • model_name_or_path:预训练模型的路径或来自 huggingface.co/models 的模型标识符。
    • cache_dir:存储从 huggingface.co 下载的预训练模型的位置。
    • use_fast_tokenizer:是否使用快速分词器(由 tokenizers 库支持)。
    • use_auth_token:是否使用运行 huggingface-cli login 时生成的令牌。
    • model_revision:要使用的特定模型版本(可以是分支名称、标签名称或提交id)。
    • quantization_bit:量化模型的位数。
    • quantization_type:用于 int4 训练的量化数据类型。
    • double_quantization:是否使用双重量化。
  • GeneratingArguments 类

    此类定义了与生成相关的参数。

    • do_sample:是否使用采样,否则使用贪婪解码。
    • temperature:用于调制下一个令牌的概率的值。
    • top_p:保持概率相加为 top_p 或更高的最小的最可能的令牌集。
    • top_k:用于 top-k 过滤的最高概率词汇令牌的数量。
    • num_beams:波束搜索的波束数。1 表示没有波束搜索。
    • max_length:生成令牌可以具有的最大长度。
    • max_new_tokens:要生成的令牌的最大数量,忽略提示中的令牌数量。
    • repetition_penalty:重复惩罚的参数。1.0 表示没有惩罚。
    • length_penalty:与基于波束的生成一起使用的长度的指数惩罚。

check.py:包含一些验证或检查逻辑。

other.py:包含一些其他的实用工具或函数。

  • 无效分数处理,将Nan和Inf归零,遍历模型的所有参数。

  • 将符合条件的 LayerNorm 层的参数数据类型转换为 float32。条件是参数维度为1且名称包含在 layer_norm_names 列表中的任意一个字符串。

  • 这个函数用于加载可训练参数的检查点文件到模型中。

  • 平滑

  • 绘制损失函数图

peft_trainer.py: PEFT 的训练策略有关。

  • 后续即将会使用

pairwise.py:与成对的数据操作或成对的损失函数有关。

  • 实现了一些用于处理成对数据(Pairwise Data)的工具和类,主要包括数据收集器和训练器。这些工具主要用于训练模型时计算成对数据的损失和准确率。
  • PairwiseDataCollatorWithPadding类:继承自 DynamicDataCollatorWithPadding,用于处理成对数据。__call__ 方法将输入特征进行填充,使得每个 batch 中的序列长度一致。它生成了 2 倍的样本,其中前 n 个样本是被选择的,后 n 个样本是被拒绝的。
  • PairwisePeftTrainer类:该类继承自 PeftTrainer,用于计算成对数据的损失。

common.py:包含一些公共的工具或函数。

  • 数量巨大,随后随使用随分析。
  • 模型加载:文件中的一部分关注于加载模型。根据不同的条件和参数,它决定从哪个检查点或路径加载模型。此外,它还处理了如何为不同的阶段准备模型,包括添加值头或加载奖励模型等。
  • 适配器初始化:模型还经历了适配器的初始化过程,这可能涉及到模型的微调或其他特定任务的适应。
  • 数据集预处理:根据训练的阶段,文件定义了如何预处理数据集。例如,对于预训练、有监督的微调、无监督的微调、对偶训练等,都有不同的预处理函数。
  • 打印数据集示例:为了给用户提供更好的可见性,文件中有一些功能可以打印数据集的示例,以便用户可以查看经过预处理后的数据如何呈现。
  • 其他实用函数:文件中还包含了其他一些实用函数,例如用于调整学习率的函数、用于计算奖励的函数等。

template.py:包含一些调用大模型的模板代码或函数。

  • 类通过不同的模板配置,可以生成不同风格和格式的对话提示和内容,适用于各种对话场景。每个模板都可以配置特定的前缀、提示格式和分隔符,并且可以选择是否使用历史记录。这使得该类非常灵活和易于扩展,能够适应多种对话系统的需求。

data_collator.py:与数据整理或预处理有关的代码。

  • DynamicDataCollatorWithPadding 类通过继承 DataCollatorWithPadding,增加了动态填充和生成注意力掩码的功能,特别适用于批处理数据的填充和序列处理。它能够处理包含 input_idslabels 的数据,确保在训练和评估时使用左侧填充,并生成相应的注意力掩码来标识填充部分。

ppo.py:与 PPO (Proximal Policy Optimization) 训练策略相关的代码。

seq2seq.py:与序列到序列模型相关的代码。与后续模型训练代码相关。文件主要涉及序列到序列模型的实现和评估。

  • 导入模块:文件开始部分导入了必要的库和工具。

  • 日志设置:使用 get_logger 函数创建日志记录器。

  • ComputeMetrics 类:此类将分词器包装到度量函数中,主要用于 Seq2SeqPeftTrainer 中。

  • call 方法:使用模型预测来计算度量标准。主要关注 Rouge 和 BLEU 分数的计算。

  • Seq2SeqPeftTrainer 类:继承自 PeftTrainer,用于序列到序列任务的训练。

  • compute_loss 方法:计算损失。

  • log_metrics 方法:记录模型的度量标准。

  • training_step 和 training_step_and_backward 方法:定义模型的训练步骤。

  • prediction_step 方法:定义模型的预测步骤,包括在生成的令牌中删除提示部分。

  • save_predictions 方法:将模型的预测结果保存到 output_dir。

PT部分

预训练(Pre-Training)是指在大规模无标注数据集上对模型进行初步训练,使模型学习到广泛的语言特征和结构。

功能

  • 捕捉基础特征:预训练阶段模型通过大规模文本数据学习语言的基本特征,如词汇、语法和上下文关系。
  • 迁移学习的基础:预训练后的模型可以用于多种下游任务,通过微调(Fine-Tuning)进一步优化特定任务的性能。
  • 减少标注数据需求:通过预训练,模型已经掌握了广泛的语言知识,微调时只需要较少的标注数据即可达到较好的效果。

初步代码实现:

# coding=utf-8
# Implements several parameter-efficient pre-training method.
# This code is inspired by
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py


import math

from utils import (
    DynamicDataCollatorWithPadding,
    PeftTrainer,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
    dataset = prepare_data(model_args, data_args)

    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")

        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

前置读取分析:

由于后续的代码都需要从执行主要函数的前5行的类似代码

model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
dataset = prepare_data(model_args, data_args)

model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

因此在此处先行进行解释,后续的部分不再解释这一部分的代码:

prepare_args

首先通过 utils.commom中的prepare_args方法来读取模型、数据、训练、微调的一些参数

def prepare_args(
        stage: Literal["pt", "sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    # 使用 HfArgumentParser 解析参数,可以从 JSON 文件或命令行获取参数。
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    # 检查各种参数的合法性,确保参数设置符合要求。
    data_args.init_for_training()

    assert stage == "sft" or (not training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."

    assert not (training_args.do_train and training_args.predict_with_generate), \
        "`predict_with_generate` cannot be set as True while training."

    assert (not training_args.do_predict) or training_args.predict_with_generate, \
        "Please enable `predict_with_generate` to save model predictions."

    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
        "Quantization is only compatible with the LoRA method."

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
        else:
            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
                "Quantized model only accepts a single checkpoint."

    if model_args.quantization_bit is not None and (not training_args.do_train):
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enable fp16 mixed precision training.")

    if data_args.prompt_template == "default":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

    # 根据优化器设置和量化位数调整训练参数。
    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args

prepare_args 函数根据训练阶段准备和检查训练参数,并设置日志记录和警告。

该函数解析命令行或 JSON 文件中的参数,确保参数的合法性,并进行必要的配置和建议。

最终返回包含模型参数、数据训练参数、序列到序列训练参数和微调参数的元组。

prepare_data

随后通过utils.common 中的 prepare_data方法来读取基本的数据

def prepare_data(
        model_args: ModelArguments,
        data_args: DataTrainingArguments
) -> Dataset:

    # 定义一个 checksum 辅助函数,用于校验文件的 SHA-1 哈希值是否匹配。如果不匹配,记录一个警告日志。
    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    # 定义一个字典 ext2type,将文件扩展名映射到相应的数据类型。
    ext2type = {
        "csv": "csv",
        "json": "json",
        "jsonl": "json",
        "txt": "text"
    }

    # 获取 max_samples 参数,用于限制加载的数据样本数量。
	# 定义一个列表 all_datasets,用于存储加载的所有数据集。
    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    # 遍历 data_args.dataset_list 中的数据集属性,依次加载每个数据集。
	# 记录加载数据集的日志信息。
    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))
		
        '''
        根据 dataset_attr.load_from 的值,决定如何加载数据集:
        如果数据集来自 hf_hub,直接使用数据集名称。
        如果数据集来自 script,构建数据集路径。
        如果数据集来自文件系统,根据文件类型和数量设置 data_path 和 data_files。
        如果需要,进行文件的 SHA-1 校验。
        '''
        if dataset_attr.load_from == "hf_hub":
            data_path = dataset_attr.dataset_name
            data_files = None
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_files = None
        elif dataset_attr.load_from == "file":
            data_path = None
            data_files: List[str] = []

            if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                    data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))

                    if data_path is None:
                        data_path = ext2type.get(data_files[0].split(".")[-1], None)
                    else:
                        assert data_path == ext2type.get(data_files[-1].split(".")[-1], None), "file type does not match."
            elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
                data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
                data_path = ext2type.get(data_files[0].split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."

            if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
                checksum(data_files[0], dataset_attr.dataset_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
        else:
            raise NotImplementedError

        '''
        使用 load_dataset 函数加载数据集。
        根据 data_args.split 获取指定的分割数据集。
        如果设置了 max_samples,限制加载的数据样本数量。
        为数据集添加列,确保所有数据集具有一致的列名。
        将处理后的数据集添加到 all_datasets 列表中。
        '''
        raw_datasets = load_dataset(
            data_path,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None
        )
        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        prefix_data = [dataset_attr.source_prefix] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        dataset = dataset.add_column("prefix", prefix_data)
        all_datasets.append(dataset)

    '''
    如果只有一个数据集,直接返回该数据集。
	如果有多个数据集,将它们合并后返回。
    '''
    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets

load_pretrained

函数用于加载预训练的模型和分词器,并根据不同的训练或推理阶段进行相应的配置。它支持模型量化、自动类注册和适配器初始化,并在需要时处理奖励模型和近端策略优化。函数最终返回配置好的模型和分词器。

def load_pretrained(
        model_args: ModelArguments,
        finetuning_args: FinetuningArguments,
        is_trainable: Optional[bool] = False,
        stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.
    
        stage:当前阶段,可以是 "pt"(预训练)、"sft"(监督微调)、"rm"(奖励模型)或 "ppo"(近端策略优化)。
        检查 is_trainable 和 checkpoint_dir 参数,并根据情况设置 finetuning_args。
        验证阶段是否合适,确保 RM 和 PPO 阶段只能与 LoRA 方法一起使用。
        
    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with the LoRA method."

    # 设置加载配置的参数。从预训练模型路径加载分词器,并设置填充标记 ID。
    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
        tokenizer.pad_token_id = 0 # set as the <unk> token

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    is_mergeable = True

    # Quantization configurations (using bitsandbytes library).
    # 加载模型配置。根据量化位数设置相应的量化配置,使用 bitsandbytes 库进行 4 位或 8 位量化。设置 device_map,确保模型在正确的设备上加载。
    if model_args.quantization_bit is not None:
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )

        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )

        is_mergeable = False
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    # 如果不是可训练模式,设置 device_map 为 "auto" 以自动分配设备。根据是否有检查点目录选择要加载的模型路径。加载预训练模型,设置数据类型和低 CPU 内存使用。
    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
        model_to_load = model_args.model_name_or_path

    # Load and prepare pretrained models (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        model_to_load,
        config=config,
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    # Register auto class to save the custom code files.
    # 注册自动类以保存自定义代码文件。初始化模型适配器,根据是否为可训练模式进行相应配置
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
        model.__class__.register_for_auto_class()

    # Initialize adapters
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

    # 如果阶段是 RM 或 PPO,添加值头以评估奖励。加载奖励模型权重并进行相应配置。
    if stage == "rm" or stage == "ppo": # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo": # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    # 如果不是可训练模式,固定所有模型参数,并根据需要将模型数据类型转换为 fp16。打印可训练参数信息。返回加载的模型和分词器。
    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

    print_trainable_params(model)

    return model, tokenizer

preprocess_data

preprocess_data 函数根据不同的训练阶段对数据集进行预处理。它定义了多个预处理函数,处理预训练数据、监督数据、无监督数据和成对数据。根据不同的阶段选择合适的预处理函数,并使用 dataset.map 对数据集进行批处理预处理。预处理完成后,打印示例并返回处理好的数据集。

def preprocess_data(
        dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        data_args: DataTrainingArguments,
        training_args: Seq2SeqTrainingArguments,
        stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers 定义一个 get_dialog 函数,用于根据示例生成对话内容。
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                prefix = examples["prefix"][i] if examples["prefix"][i] else ""
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    # 定义一个 preprocess_pretrain_dataset 函数,用于预训练数据集的预处理。将文本拼接成指定长度的块,并生成输入和标签。
    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
        # for input with history, we build multiple input-label pairs just like:
        # 定义一个 preprocess_supervised_dataset 函数,用于监督数据集的预处理。生成格式为 <bos> X Y <eos> 的输入和 <ignore> ... <ignore> Y <eos> 的标签。
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        max_length = data_args.max_source_length + data_args.max_target_length

        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=True)
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)

                if len(source_ids) > data_args.max_source_length:
                    source_ids = source_ids[:data_args.max_source_length]
                if len(target_ids) > data_args.max_target_length - 1: # eos token
                    target_ids = target_ids[:data_args.max_target_length - 1]

                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
                    break

                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)

        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `<bos> X` and labels with format `<bos> Y`
        # 定义一个 preprocess_unsupervised_dataset 函数,用于无监督数据集的预处理。生成格式为 <bos> X 的输入和 <bos> Y 的标签。
        model_inputs = {"input_ids": [], "labels": []}

        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=True)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(target_ids) > data_args.max_target_length:
                target_ids = target_ids[:data_args.max_target_length]

            model_inputs["input_ids"].append(source_ids)
            model_inputs["labels"].append(target_ids)

        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
        # 定义一个 preprocess_pairwise_dataset 函数,用于成对数据集的预处理。生成格式为 <bos> X Y1 <eos> 和 <bos> X Y2 <eos> 的输入对。
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length:
                source_ids = source_ids[:data_args.max_source_length]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
                             skip_special_tokens=False)
        ))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    # 使用 dataset.map 函数对数据集进行预处理。根据阶段选择并打印预处理后的示例。返回预处理后的数据集。
    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        if stage == "pt":
            print_unsupervised_dataset_example(dataset[0])
        elif stage == "sft":
            print_supervised_dataset_example(dataset[0])
        elif stage == "rm":
            print_pairwise_dataset_example(dataset[0])
        elif stage == "ppo":
            print_unsupervised_dataset_example(dataset[0])

        return dataset

DynamicDataCollatorWithPadding

DynamicDataCollatorWithPadding 类通过继承 DataCollatorWithPadding,增加了动态填充和生成注意力掩码的功能,特别适用于批处理数据的填充和序列处理。它能够处理包含 input_idslabels 的数据,确保在训练和评估时使用左侧填充,并生成相应的注意力掩码来标识填充部分。

综合考虑,这部分的实现如下功能:

准备预训练模型和数据集参数(从命令行或 JSON 文件中获取参数、配置日志记录、检查参数的合法性、根据不同的阶段设置相应的训练参数)

准备数据集(根据数据参数中的数据集列表,加载指定的数据集。支持从 Hugging Face Hub、脚本或本地文件系统加载数据集。检查并验证数据集文件的完整性。加载数据集,并根据需要限制样本数量。统一数据集的列名,确保数据集具有一致的结构。)

加载预训练模型和分词器(设置加载配置的参数。从预训练模型路径加载分词器,并设置填充标记 ID。根据量化位数设置相应的量化配置,使用 bitsandbytes 库进行 4 位或 8 位量化。根据是否有检查点目录选择要加载的模型路径。加载预训练模型,设置数据类型和低 CPU 内存使用。注册自动类以保存自定义代码文件。初始化模型适配器,根据是否为可训练模式进行相应配置。处理 RM 和 PPO 阶段时,添加值头以评估奖励。如果不是可训练模式,固定所有模型参数,并根据需要将模型数据类型转换为 fp16。)

预处理数据集(根据不同的训练阶段定义不同的预处理函数:预训练数据集预处理函数(preprocess_pretrain_dataset)。监督数据集预处理函数(preprocess_supervised_dataset)。无监督数据集预处理函数(preprocess_unsupervised_dataset)。成对数据集预处理函数(preprocess_pairwise_dataset)。选择合适的预处理函数,并使用 dataset.map 对数据集进行批处理预处理。打印预处理后的数据集示例(用于调试和验证)。返回预处理后的数据集。)

准备数据填充器(继承自 DataCollatorWithPadding,增加了动态填充和生成注意力掩码的功能。根据是否忽略填充标记计算损失,设置标签填充值。在批处理数据时,为序列生成注意力掩码,并填充到批次中最长的序列长度。)

后置训练部分

后续的训练过程首先是一个简单的利用 python.dataset.train_test_split方法来进行的交叉验证分割。

    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

随后便是定义模型,训练模型,保存模型,评估模型的经典步骤,其中定义模型该阶段采用的我们在 peft_trainer 中实现的集成于python.transformers中的Seq2SeqTrainer 模型。

class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    继承自Seq2SeqTrainer,实现了自定义模型保存和加载功能
    """

    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoint.

        This function will only be executed at the process zero.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        保存训练参数作为模型检查点。该方法只在主进程中执行。
        创建输出目录并确保其存在。
        通过 unwrap_model 获取基础模型。如果模型有 pretrained_model 属性(如在使用 LoRA 时),则获取它并保存 v_head 的状态字典。
        根据微调类型(lora 或其他)保存模型:
        如果是 LoRA 微调,调用 save_pretrained 方法保存基础模型的参数。
        如果是全参数或冻结微调,临时设置 use_cache 属性为 True,保存模型后再将其恢复为 False。
        如果有 tokenizer,则将其保存到输出目录。
        将训练参数和微调参数分别保存到 TRAINING_ARGS_NAME 和 FINETUNING_ARGS_NAME 文件中。
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        model = unwrap_model(self.model)

        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
            backbone_model = getattr(model, "pretrained_model")
            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
        else:
            backbone_model = model

        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
        else: # freeze/full tuning
            backbone_model.config.use_cache = True
            backbone_model.save_pretrained(
                output_dir,
                state_dict=get_state_dict(backbone_model),
                safe_serialization=self.args.save_safetensors
            )
            backbone_model.config.use_cache = False
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)

        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
            f.write(self.args.to_json_string() + "\n")
        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

    def _load_best_model(self):
        r"""
        Loads trainable parameters from model checkpoint.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.

        从最佳检查点加载可训练参数。该方法只在主进程中执行。
        记录加载最佳模型的日志信息。
        通过 unwrap_model 获取基础模型。如果模型有 pretrained_model 属性,则获取它作为 backbone_model。
        根据微调类型(lora 或其他)加载模型:
        如果是 LoRA 微调,加载适配器并设置 v_head 参数。
        如果是全参数或冻结微调,调用 load_trainable_params 函数加载可训练参数。
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

        model = unwrap_model(self.model)
        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })
        else: # freeze/full-tuning
            load_trainable_params(backbone_model, self.state.best_model_checkpoint)

其中有关LoRA相关的已经在之前的博客中提及。

其中不继承的函数 _load_best_model _save已经在上述代码中阐述。相关继承的方向,Seq2SeqTrainer 是 Hugging Face transformers 库中用于训练序列到序列(Seq2Seq)模型的类。它提供了一些方法和属性,方便进行模型训练和评估。以下是 Seq2SeqTrainer 的一些关键功能:

  • 训练和评估:提供训练和评估的核心方法,如 trainevaluate
  • 保存和加载模型:提供保存和加载模型检查点的方法,如 save_modelload_model
  • 分布式训练:支持多 GPU 和分布式训练。
  • 混合精度训练:支持使用 fp16bf16 进行混合精度训练。
  • 日志记录和监控:集成了日志记录和监控功能,支持 TensorBoard 等工具。

PeftTrainer 通过继承 Seq2SeqTrainer,扩展了其功能,支持参数高效的检查点保存和加载。它在初始化时进行一些必要的设置,并重写了 _save_load_best_model 方法,以支持 LoRA 等微调方法。通过这些扩展,PeftTrainer 可以更高效地管理模型的训练和微调过程。

posted on 2024-05-31 12:43  weiZhang2024  阅读(102)  评论(0)    收藏  举报