如何微调-Granite-Vision-2B-以击败-90B-模型--洞察与经验教训

如何微调 Granite-Vision 2B 以击败 90B 模型——洞察与经验教训

原文:towardsdatascience.com/how-i-fine-tuned-granite-vision-2b-to-beat-a-90b-model-insights-and-lessons-learned/

微调大型语言或视觉语言模型是一种强大的技术,它解锁了它们在特定任务上的潜力。然而,尽管这些方法非常有效,但由于它们的高计算成本和需要具有大 VRAM 的 GPU——这些资源只有一小部分最终用户可以访问——这些方法通常对许多用户来说遥不可及。

在这个项目中,我微调了 IBM 的Granite-Vision 2B,这是一个相对较小但功能强大的视觉语言模型,以应对将表格图像转换为干净的、结构化的 HTML 代码的挑战。

使这个项目特别令人兴奋的是,微调是在消费级 GPU——NVIDIA RTX 4070 Ti Super——上进行的,而结果却是这个 20 亿参数的模型能够在图像到文本生成任务上超越许多更大的模型,包括meta-llama/Llama-3.2–90B-Vision。这一成功不仅证明了参数高效微调方法(如 LoRA)的力量,也突出了构建针对特定问题量身定制的小型模型的实际价值。

在这篇文章中,我将向您介绍这项工作的动机,包括模型和数据集的选择,我采用的定制 HTML 相似度指标,实验和结果,以及在整个过程中获得的关键洞察和经验教训。无论您对视觉语言模型、微调技术还是实际的人工智能应用感兴趣,我都希望这次旅程能为您提供有用的收获。用于此项目的微调代码是从HuggingFace 的 Granite Vision 微调食谱改编的,该食谱由 Eli Schwartz 撰写,他反过来又从 Sergio Paniego 那里改编了原始代码。

动机

在进行检索增强生成(RAG)项目时,我遇到了一个主要挑战:从 PDF 中准确提取大型和复杂的表格,尤其是在这些表格以图像形式出现时。尽管尝试了不同的方法——包括像 Unstructured 这样的工具以及 Meta 的 Llama 90B 这样的大型视觉语言模型——但结果往往达不到所需的准确性。

这使我考虑了一种不同的方法:一个专注于表格理解和提取的专用小型视觉语言模型。这样的模型可以作为 RAG 管道的专用预处理步骤,显著提高依赖于准确表格提取的 RAG 管道。

大约与此同时,IBM 发布了 Granite-Vision 2B——一个大小和功能平衡恰到好处的视觉语言模型。它足以处理复杂的表格,同时体积足够小,可以在具有 16 GB VRAM 的消费级 GPU 上进行微调。这使得它成为我项目的理想候选者。

任务:图像到 HTML(表格提取)

一个重要的设计选择是目标格式:HTML。通过将表格转换为干净的 HTML 代码,我们获得了一种结构化和广泛支持的表示形式,可以轻松转换为其他格式。例如,HTML 表格可以轻松导入到 Pandas 等数据分析工具作为数据框,使得下游处理和分析更加高效。

原计划是通过提取 HTML 表格标签,将它们渲染为图像,并将每个图像与其对应的 HTML 代码配对来构建一个自定义数据集。幸运的是,我找到了一个解决方案:PubTabNet-HTML数据集,该数据集包含超过 568,000 个图像-HTML 配对,远远超过了这个项目所需的数量。

PubTabNet 是由 IBM 开发的,基于PubMed Central Open Access Subset(商业用途集合)中的科学文章。表格是通过将文章的 PDF 和 XML 版本对齐来提取的。标注(即 HTML 标签)根据社区数据许可协议 – 宽泛 – 第 1.0 版进行许可,虽然IBM 不拥有图像,但它们的使用符合PMC 开放获取子集使用条款。这使得数据集适合研究和商业应用,前提是遵守许可条款。

自定义指标:HTML 相似度

标准文本相似度指标如 BLEU 或 ROUGE 对于评估 HTML 表格生成是不够的,因为它们主要关注表面级别的文本匹配,而忽略了 HTML 代码的重要结构和风格方面。

为了更好地捕捉生成的 HTML 表格的质量,我采用了一个自定义的HTML 相似度指标,该指标结合了多个互补的组件,其中最重要的(样式和结构)是从niteru导入的:

  • 样式相似度(S):提取每个 HTML 文档的 CSS 类,并计算类集合的 Jaccard 相似度。

  • 结构相似度(T):使用 HTML 标签的序列比较来计算相似度。

  • 内容相似度(C):基于提取的表格纯文本内容的归一化编辑距离。

  • 标记重叠相似度(J):内容标记集合之间的 Jaccard 相似度。

最终的相似度得分M是这些组件的加权总和:

我手动测试了该指标在各种示例输出上,迭代调整权重系数以更好地捕捉有意义的相似性。这个过程产生了一个平衡的评估,公平地奖励准确的表格结构和样式,以及精确的文本内容。Python 实现如下:

from torchmetrics.text import EditDistance
from niteru import style_similarity, structural_similarity

ed_distance = EditDistance()

def extract_table_text(html):
    """Extracts only the text from an HTML table in row-wise space-separated format."""
    soup = BeautifulSoup(html, "html.parser")
    table = soup.find("table")  # Find the first table
    if not table:
        return ""
    # Extract rows and join cells with spaces
    return "\n".join(" ".join(cell.get_text(strip=True) for cell in row.find_all(["th", "td"])) for row in table.find_all("tr"))

def extract_html_table(html):
    """Extracts html table from text"""
    match = re.search(r'<table\b.*?</table>', html, re.DOTALL | re.IGNORECASE)
    if match:
        table_html = match.group()
        return table_html
    else:
        return html

def html_similarity(html1, html2):
    html1 = extract_html_table(html1)
    html2 = extract_html_table(html2)
    # Compute individual similarity scores
    style_sim = style_similarity(html1, html2)  # Assume returns [0,1]
    struct_sim = structural_similarity(html1, html2)  # Assume returns [0,1]
    txt1, txt2 = extract_table_text(html1), extract_table_text(html2)
    content_sim = 1 - (ed_distance(txt1, txt2) /
                                   max(len(txt1), len(txt2) + 1e-10))  # Avoid division by zero
    jaccard_sim = 1 - (len(set(txt1.split()).intersection(set(txt2.split()))) /
                        len(set(txt1.split()).union(set(txt2.split()))) + 1e-10)

    # Weighted sum of the similarities
    final_score = (0.10 * style_sim) + (0.40 * struct_sim) + (0.30 * content_sim) + (0.20 * jaccard_sim)
    # Ensure final score is in [0,1]
    final_score = max(0, min(1, final_score))
    return final_score

该指标还包括一个基于正则表达式的函数,用于提取<table>标签内的 HTML 内容。这是必要的,因为其中一个参考模型只生成了表格结构之外的完整或不完整的 HTML。通过严格关注表格内容,该指标为模型提供了更公平和有意义的评估。

开发这样一个定制的评估指标对于可靠地跟踪模型改进和将性能与参考模型进行基准测试至关重要。

训练设置

为了高效地在我的 NVIDIA RTX 4070 Ti Super 上微调模型,该显卡拥有 16 GB VRAM,我使用了LoRA(低秩自适应)。这使我能够只更新少量参数,显著减少 GPU 内存使用。实际上,在训练过程中,模型只使用了大约一半的可用 VRAM——有足够的余量来处理更长的序列,但不足以处理超过一个批次。此外,与 QLoRA 等方法相比,LoRA 通常训练速度更快。

LoRA 设置

我使用了以下 LoRA 配置:

# Setup LoRA
target_modules = []
for layer_type in layers_to_tune:
    target_modules.extend(
        name for name, _ in model.named_modules()
        if (layer_type in name) 
        and '_proj' in name
    )
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=target_modules,
    use_dora=True,
    init_lora_weights="gaussian"
)

关键点:

  • r=16: 这个低秩维度在模型容量和 GPU 内存使用之间提供了良好的平衡。

  • use_dora=True: DoRA(权重分解低秩自适应)通过将预训练权重分解为幅度和方向分量,提高了 LoRA 的学习能力和稳定性,帮助模型更好地接近全微调的能力——所有这些都不需要增加推理开销。表现略优于默认设置。

  • init_lora_weights="gaussian": 没有特别的原因,我不想在这个参数上做实验。

  • target_modules: 这个灵活的设置允许根据实验选择性地针对视觉层、语言层或两者,具体取决于实验。在实践中,视觉层保持未受影响——即使use_dora=False——因为 DoRA 目前仅支持嵌入、线性层和 Conv2d 层。因此,我只微调了语言层。

数据集设置

在我的初步实验中,我不断遇到内存不足(OOM)错误——尽管在加载模型后仍有大量的可用 GPU VRAM(大约还有 4GB 空闲)。训练过程中没有内存峰值,但崩溃总是在相同的训练步骤发生。

经过一些调查,我发现问题是由大表格引起的,这导致了非常长的标记序列。为了解决这个问题,我调整了max_seq_length参数,并过滤掉了超过这个限制的样本。经过实验,我发现使用max_seq_length = 1024可以让我可靠地微调模型,而不会触发 OOM 错误。

为了过滤掉过大的表格,我编写了一个简单的数据处理函数,该函数:

  • 过滤掉 HTML 标记长度超过max_seq_length的样本

  • 自动平衡训练和测试样本的数量

  • 使用流式处理来避免将整个数据集加载到内存中(PubTabNet-HTML 相当大,在磁盘上大约有 10 GB)

def load_process_filter_dataset(dataset, max_seq_length, num_train_images, num_test_images, system_message):
    global processor
    ds = load_dataset(dataset, split='train', streaming=True)
    max_html_tokens = max_seq_length - len(processor.tokenizer.tokenize(system_message))
    num_total_needed = num_train_images + num_test_images
    filtered_samples = []
    p_bar = tqdm(total=num_total_needed, desc="Filtering dataset samples")
    for sample in ds:
        processed = process_and_filter_example(sample, max_html_tokens)
        if processed:
            filtered_samples.append(processed)
            p_bar.update(1)
        if len(filtered_samples) >= num_total_needed:
            break
    p_bar.close()
    # Convert to in-memory dataset
    ds_filtered = Dataset.from_list(filtered_samples)
    # Split into train/test
    ds_train = ds_filtered.select(range(num_train_images))
    ds_test = ds_filtered.select(range(num_train_images, num_total_needed))
    return ds_train, ds_test

def process_and_filter_example(example, max_html_tokens):
    global processor
    extracted_table = extract_html_table(example['html_table'])
    token_count = len(processor.tokenizer.tokenize(extracted_table))
    if token_count < max_html_tokens:
        example['html_table'] = extracted_table
        return example
    return None

最终配置包括num_train_images=10000num_test_images=250,以计算评估损失。

微调配置

对于训练,我使用了Transformers SFTTrainer来微调模型:

# Training arguments
    training_args = SFTConfig(
        output_dir=f"src/models/{model_name.split('/')[-1].replace('-', '_', 1).split('-')[0]}/checkpoints/{experiment_name}",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=gradient_accumulation_steps,
        max_seq_length=max_seq_length,
        warmup_steps=10,
        learning_rate=3e-4,
        weight_decay=0.01,
        logging_strategy="steps",
        eval_strategy='steps',
        logging_steps=25,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=1,
        greater_is_better=False,
        load_best_model_at_end=True,
        optim="adamw_torch_fused",
        bf16=True,
        push_to_hub=False,
        report_to="wandb" if not debug else "none",
        remove_unused_columns=False,
        gradient_checkpointing=True,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        dataset_num_proc=8
    )

# Setup Trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor.tokenizer
    )

关键点:

  • num_train_epochs=1:数据集非常大,为了高效地运行多个实验,我选择只训练一个完整周期,同时最大化每个样本和训练样本的数量。

  • per_device_train_batch_size=1:更大的批量大小无法适应 GPU 内存,除非显著减少 max_seq_length——这将损害大型表格上的性能。保持更长的序列对于这个任务来说更为重要。

  • gradient_accumulation_steps=8:用于有效地模拟更大的批量大小,并帮助稳定学习过程,补偿较小的物理批量。这是最终值,但也尝试过gradient_accumulation_steps=4

  • optim="adamw_torch_fused"bf16=True:这些设置利用现代 NVIDIA 架构(Ada Lovelace)来加速训练并减少内存使用——正如为这种硬件所推荐的那样。

评估损失解决方案

在开发项目的时候,Transformers + LoRA 集成中存在一个已知问题,在训练期间使用验证数据集进行评估时会导致错误。幸运的是,有一个经过社区测试的解决方案可用(尽管尚未合并到主分支),我在实验中成功地使用了这个修复。

评估(推理)设置

用于最终评分的评估数据集与训练期间使用的eval_dataset完全独立。它由500 个随机选择的图像组成,其中没有任何一个图像包含在train_dataset或训练的eval_dataset中。

一旦微调完成,我使用了基于最低评估损失的最佳模型检查点来对这些 500 个样本进行推理。

初始时,我尝试通过在基模型之上简单地加载 LoRA/DoRA 适配器来进行推理。然而,我发现当DoRA 适配器未合并到模型权重中时(如官方 PEFT 文档中所述),推理速度非常慢。实际上,在这种配置下生成一个测试随机样本大约需要90 秒

为了解决这个问题,我将适配器权重合并到基模型中——这是推荐的实践——合并后,推理速度显著提高:对于相同样本,推理时间缩短到约 20 秒,使得全面评估运行变得更加实用。

用于与我微调的模型进行比较的参考模型是:

  • meta-llama/Llama-3.2–90B-Vision: Meta 的巨型90 亿参数模型——我通过专业化和对更小 VLM 参数高效微调来超越的主要基线。

  • KennethTM/pix2struct-base-table2html: 从谷歌的pix2struct-base微调的一个更小的模型,针对与我在这项项目中使用的完全相同的数据集进行了高度专业化。由于其较小的尺寸,开发者能够对更多样本进行训练,并在更长的训练周期中进行训练——展示了使用较小、针对特定任务的模型的关键优势。

这两个基线使我能够基准测试基于缩放的性能(与 90B 模型相比)和专业化效率(与较小、专门的 Pix2Struct 模型相比)。

实验 & 结果

共进行了 9 次实验,每次迭代性地修改一个或两个组件。目标是了解每个变化对模型性能的影响,逐渐细化设置,以实现与参考模型相比最佳的HTML 相似度得分。

实验过程是逐步进行的:每当一个变化改善了结果,它就会被纳入下一轮实验,并继续探索新的变化。

实验主要集中在调整以下组件:

  1. 视觉与语言层
  • 1.1 lang_only

  • 1.2 vision_only

  • 1.3 lang_vision

2. 真实输出格式

  • 2.1 lang_table_only

3. 训练框架

  • 3.1 lang_table_unsloth

  • 3.2 vision_table_unsloth

4. 梯度累积

  • 4.1 lang_table_only_2

5. 提示格式

  • 5.1 lang_table_only_3

6. 梯度累积与数据集大小

  • 6.1 lang_table_only_4

评估损失和HTML 相似度指标都被用来评估模型性能,我发现它们之间有很好的相关性——证实了 HTML 相似度是衡量模型学习任务好坏的良好代理。

在深入每个实验的结果之前,让我们首先看看 训练期间的 GPU 内存利用率,这通常是确定模型是否可以在消费级硬件上进行微调的最关键因素。

GPU 内存利用率训练期间 | 图由作者来自 wandb.ai

如图中所示,GPU 利用率在整个训练过程中保持稳定——平均大约为 75% VRAM 使用率,或者在我的 GPU 上大约为 12 GB。VRAM 使用的大部分(约 5.5 GB)是冻结的模型权重。LoRA 梯度 + 优化器状态占用非常少(<< 1 GB)。激活 + 额外开销应填满剩余部分(约 5–6 GB),这取决于 batch_size 和 max_seq_length

第一次运行: lang_only

这个实验使用了以下初始组件/参数:

这些是第一次实验的起始值。在后续运行中,我根据改进的方法修改了许多值。这个第一次实验仅关注调整 语言层,同时训练模型预测完整的原始 HTML 输出——包括<table>标签内和周围的全部内容。

由于这是第一次运行,我将包括 训练损失曲线 以说明其行为。对于后续实验,我将省略此图——因为运行之间的行为相似,只有细微的差异。实际上, 评估损失 对于比较实验的性能更有用。

训练损失 | 图由作者来自 wandb.ai

关于日志配置的一个重要注意事项:logging_steps=25 表示训练损失仅在每 25 步之后记录,其中每个记录的值是gradient_accumulation_steps=4的平均值。因此,损失的最大下降出现在第二个记录点——这是大多数初始学习发生的地方。之后,模型继续以更缓慢的趋势学习,具体取决于训练样本的难度。

现在,让我们来看看 评估损失

验证损失 1 | 图由作者来自 wandb.ai

由于我们在同一组 250 个验证样本上评估,评估损失曲线为我们提供了一个更稳定和有意义的模型学习视图——并将作为未来运行比较的基准。

在这里,我们观察到整个训练过程中明显的持续下降趋势。初始损失接近 0.03,随着训练的进行稳步改善,最终稳定在 0.015 以下。

与更易变的训练损失相比,这条曲线的平滑性反映了验证集的规则结构,并证实了模型在未见过的样本上泛化良好,即使在较小的批量大小和单个训练周期的情况下。

现在,让我们比较这个微调模型在HTML 相似度指标上与参考模型的性能:

图片

如我们所见,这个第一次实验已经带来了显著的性能提升——将基础 Granite-Vision 2B 模型提升了很大幅度(+0.18),并且在这个专业任务上明显优于 LLaMA 90B Vision。只有 Pix2Struct 在这个阶段仍保持微弱领先。

第二次运行:vision_only

在这次运行中,没有太多可以分析的内容。我测试了几个可能解锁视觉层学习的变体——包括大幅提高学习率——但都没有成功。

尽管基础代码表明微调视觉层应该是可能的,但在实践中我发现在这个设置中它不起作用。以下评估损失曲线证实了这一点——损失在整个训练过程中保持恒定。为了避免浪费计算资源,我提前停止了运行:

图片

验证损失 2 | 图片由作者来自 wandb.ai 提供

此外,与之前的lang_only实验相比,这次训练的速度明显更快——这表明语言层(包含模型的大部分参数)保持冻结状态,只有小的视觉层正在被处理:

图片

验证样本每秒 1 | 图片由作者来自 wandb.ai 提供

第三次运行:lang_vision

到这个时候,很明显只有语言层被有效地训练。在这个lang_vision运行中——其中选择了语言和视觉层——我期望得到与lang_only相似的结果。

的确,评估损失曲线证实了这一点,显示出与lang_only几乎相同的行为:

图片

验证损失 3 | 图片由作者来自 wandb.ai 提供

一旦这一点清楚,我又提前停止了训练以节省资源,并开始测试新的方法。

第四次运行:lang_table_only

这个实验修改了以下组件:

图片

这次运行的目标是训练模型只预测表格内容,而不包含任何周围的 HTML 包装代码。这种方法可以帮助提高学习——通过移除不必要的标记——并且使训练行为更接近 Pix2Struct 的模型。

此外,通过移除包装 HTML,目标序列变得更短——这使得更长的更复杂的表格能够适应模型的环境窗口。这种变化也可能提高模型泛化到更大或更详细表格的能力。

让我们看看与第一次运行的评估损失比较:

图片

验证损失 4 | 图片由作者来自 wandb.ai 提供

初看,更高的评估损失可能看起来有些不合常理。然而,有一个明确的解释:包装 HTML 代码对模型来说很容易学习——因为它往往在许多训练样本中几乎完全相同。这些重复的标记减少了交叉熵损失,人为地降低了早期运行的平均损失。通过移除它们,模型现在完全专注于更具挑战性和可变性的表格内容——导致损失值更高但更有意义。

现在,让我们看看这个变化如何影响了 HTML 相似度 指标:

图片

在这次第一次测试中,我们没有观察到使用这种新输出格式的 显著提升或下降。可能模型需要 更多的训练轮数或 更大的训练样本才能完全适应这种新格式。另一个想法是 更新提示——这样从第一步开始,模型就明白它应该只专注于表格内容,而不是必须通过训练来推断这种行为。这将在接下来的实验中探讨。

第五 / 第六次运行:lang_table_unslothvision_table_unsloth

在这些实验中,我探索了以下组件:

图片

在这一点上,我发现了一个有希望的 Unsloth 框架——它声称可以提供 2 倍的训练速度以及高达 70% 的内存使用降低。当然,我想测试它是否可以加速我的工作流程。

我最初的想法是利用改进的内存处理来运行更长的序列(max_seq_length=2048),但在我这个例子中,这很快导致了 内存不足 (OOM) 错误——所以我恢复了之前的配置。

然而,训练速度的提高是无可否认的——几乎 4 倍 快于我之前的运行:

图片

每秒验证样本数 2 | 图片由作者来自 wandb.ai 提供

不幸的是,这明显是以损失性能为代价的 成本

图片

验证损失 5 | 图片由作者来自 wandb.ai 提供

由于质量明显下降,我暂停了实验以进行进一步调查——特别是看看 Unsloth 是否允许我训练视觉层,这是其宣传的优势之一。然而,我遇到了与 HuggingFace Transformers 完全相同的行为——视觉层没有实际的学习。

考虑到这些结果,我决定 将 Unsloth 保留在这个项目中并继续使用 HuggingFace Transformers,这在之前的运行中已经显示出更可靠的学习能力。

第七次运行:lang_table_only_2

这里是这个运行的新参数:

图片

回到之前的配置,我想分析一个 更大的虚拟批量大小(通过更高的 gradient_accumulation_steps)的影响。

结果很有希望——评估损失变得更加平滑,趋势更接近原始的lang_only运行,尽管模型现在只预测表格内容

图片

验证损失 6 | 图片由作者来自 wandb.ai

基于这个积极的结果,我决定将这个gradient_accumulation_steps=8设置用于最终实验。

HTML 相似度上评估这个模型导致了一个小但有意义的好转——最终与 Pix2Struct 达到了平衡:

图片

自然地,目标不仅仅是匹配 Pix2Struct——而是要超越它。还有两个重要的杠杆需要探索:数据集大小和提示。

第八次运行:lang_table_only_3

这次运行的更新参数为:

图片

我不小心在这个运行中将gradient_accumulation_steps回退到4,直到训练接近完成才意识到这一点——但实际上这给了我额外的机会来观察它对学习的影响。

这里的主要目标是将训练规模翻倍(到 10K 张图片)并测试更新的、更清晰的提示格式。不幸的是,一个随机的CUDA 错误导致训练在80%完成时停止——但即便如此,改进是明显的:

图片

验证损失 7 | 图片由作者来自 wandb.ai

如预期,由于虚拟批量大小的减小,一些平滑度丢失了,但新的提示证明非常有效——显著提升了模型的学习能力。

这为使用这个改进的提示、10K 个训练样本,并将gradient_accumulation_steps恢复到 8 的最终实验做好了完美的准备。

最终运行:lang_table_only_4

最终的参数集为:

图片

这次最终运行的评估损失:

图片

验证损失 7 | 图片由作者来自 wandb.ai

如预期,将gradient_accumulation_steps恢复到 8 平滑了损失曲线,减少了峰值,并实现了略低的总体损失值。在 10K 张图片上进行完整周期的训练,这使得它成为所有实验中表现最好的模型。

现在,让我们看看在HTML 相似度指标上的最终结果:

图片

最终 HTML 相似度结果 | 图片由作者来自 matplotlib

这个项目的目标已经实现——经过微调的模型现在在这个任务上超过了两个参考模型。回顾原始的 Granite-Vision 2B,LoRA 微调将性能提升到0.77,实现了+21 个百分点的提升——所有这些都在8 小时内,在一个消费级 GPU上完成。

定性结果

为了更好地说明模型通过微调的改进程度,让我们看看一个具体的例子:图像 ID 618932

图片

PubTabNet 评估样本 ID 618932 | 图像来自PMC

这个表格尤其棘手——在Kappa列下有子标题Present studyKing et al. 2001)。这些复杂的布局通常挑战通用的 VLMs,尤其是在训练期间没有接触到足够的类似示例时。模型通常可以理解这些子标题并回答关于它们的问题,但生成完整的 HTML 表格结构通常需要进一步的提示调整和专门的微调。

让我们先看看一个基础的、未经微调的 Granite-Vision 2B 模型在这个任务上的表现。

基准:原始 Granite-Vision 2B

该模型可以正确地根据表格回答问题:

prompt='What is the Kappa value for the question "Do you communicate with this power?" in the present study?'
res = predict(sample['image'], prompt=prompt)
print(res)

Out[1]:

74

然而,当要求生成完整的 HTML 表格时,模型遇到了困难:

prompt = "Convert table to HTML (<table> ... </table>)"
html = predict(sample['image'], prompt=prompt)
html = '<table>' + html + '</table>' if '<table>' not in html else html
display(HTML(html))

Out[2]:

以及这次尝试的HTML 相似度指标:

Style similarity: 1.0000
Structural similarity: 0.4091
Lev-Edit Distance: 0.1434
Final HTML Similarity Score: 0.3619

微调模型:lang_table_only_4

现在,让我们使用微调后的模型尝试完全相同的测试:

from src.models.granite_vision.transformers_library import LLM as granite_vision

model = granite_vision(
    model_path,
    adapter='lang_table_only_4'
)

Out[4]:

Model loaded
Adapter 'lang_table_only_4' loaded
Adapter 'lang_table_only_4' merged
Using cuda: NVIDIA GeForce RTX 4070 Ti SUPER

以及相同的预测提示:

prompt = "Convert table to HTML (<table> ... </table>)"
html = model.predict(sample['image'], max_new_tokens=1024, query=prompt)
display(HTML(html))

Out[5]:

微调后的模型现在生成的输出与真实值非常接近,正确地捕捉到了表格结构和子标题——这是基础模型难以做到的。

最终的HTML 相似度指标:

Style similarity: 1.0000
Structural similarity: 0.9231
Lev-Edit Distance: 1.0000
Final HTML Similarity Score: 0.9615

这个例子清楚地显示了明显的量化改进:在复杂表格结构上的得分从 0.36 提高到 0.96——证实了在这个专业任务上进行微调可以显著提高模型的能力。

推理速度

使用较小模型的一个主要优势——除了能够在消费级硬件上进行微调的能力之外——是推理速度。即使较大的模型提供具有竞争力的性能,延迟和吞吐量仍然是关键因素,尤其是在生产环境中。

让我们比较不同模型的推理速度:

推理速度 M | 图像由作者来自 matplotlib

如图中所示,Pix2Struct是迄今为止最快的模型。对于某些用例——例如批量处理数千份文档进行表格提取——这种速度优势可以转化为显著的时间节省和更低的计算成本。

然而,当需要处理的文档数量不是很大时,微调后的 Granite-Vision 2B在准确性上具有优势,推理速度合理,无需极端庞大的计算基础设施。

结论

这个项目证明了使用基于LoRA 的微调针对性任务(表格提取→HTML),一个小型视觉语言模型(Granite-Vision 2B)可以超越许多更大的模型——甚至 Meta 的 90B LLaMA Vision——同时只需要一个消费级 GPU和不到一天的训练时间。

一些关键要点:

  • 小型、专业化的模型很重要——你并不总是需要 70B+的模型来解决特定问题并获得高精度。

  • 参数高效微调(LoRA)是一个变革性的技术:适应大型基础模型对大多数从业者来说变得可行。

  • 提示设计和技术目标对结果有重大影响——小的改动(如切换到lang_table_only或细化提示)直接影响了性能。

  • 拥有一个自定义指标(HTML 相似度)对于跟踪超越通用文本指标的有意义进展至关重要。

  • 较小的模型不仅训练更快,而且推理更快——非常适合高吞吐量的生产流程。

最后——也许是最重要的——这种类型的实验表明,即使硬件有限,你也可以快速行动并迭代。微调强大的开源模型并将它们适应现实世界任务不再仅限于大型实验室。

我希望这能鼓励其他 AI 工程师在自己的项目和解决方案中尝试小型 VLM 和微调技术——并看到即使没有庞大的计算预算,也能取得强大的结果!

接下来是什么?

肯定有一些有趣的后续想法可以探索:

  • 提示工程优化:在撰写此博客的最终测试中显示,将提示分为系统消息(定义行为)和用户消息(提供任务指令)显著提高了基础模型的表现。在微调过程中应用此策略可以进一步提高模型持续生成准确 HTML 的能力。这将在即将进行的实验中得到测试。

  • 训练视觉层:目前,只有语言层被微调,因为仅通过文本损失训练视觉层已被证明无效。一种更先进的方法可能包括添加辅助视觉损失——例如,在视觉输出和 HTML 结构之间进行对比学习——以更好地适应表格提取任务的视觉骨干。

  • 提高泛化能力:当前模型是在单个数据集上微调的。将训练扩展到包括更多样化的文档布局、表格样式和有噪声的 OCR 场景可以提高鲁棒性和对现实世界数据的迁移性。

链接


如果你喜欢这篇帖子,请随时联系或分享你自己的实验!

posted @ 2026-03-28 09:35  绝不原创的飞龙  阅读(1)  评论(0)    收藏  举报