How to freeze LLaVA's parameters so they stay fixed during training

The idea is simple: set requires_grad=False on all of LLaVA's parameters and train only your own modules (e.g. the mask head). Below are a few common patterns; pick whichever fits.

  1. Freeze all of LLaVA (vision encoder + projector + language model)
import torch

def freeze_llava(llava):
    # Turn off gradients for every LLaVA parameter
    for p in llava.parameters():
        p.requires_grad = False

# In your wrapper (using XVQAModel as an example)
model = XVQAModel(llava_model=llava, ...)
freeze_llava(model.llava)

# Hand only the modules you actually train to the optimizer (e.g. the mask head)
optimizer = torch.optim.AdamW(
    params=[p for p in model.mask_head_deocder.parameters() if p.requires_grad],
    lr=1e-4
)

Remember to pass only the trainable parameters to the optimizer; otherwise the frozen ones get stuffed in as well (they won't be updated, but they waste time and memory).
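
If you'd rather not enumerate trainable submodules by hand, a minimal sketch (assuming the same `model` wrapper from above, with LLaVA already frozen) is to filter over the whole model:

# Collect every parameter that is still trainable anywhere in the wrapper;
# after freeze_llava() this is effectively just the mask head.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)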

  2. Selective freezing (keep only the projector or certain layers trainable)

Train only the mm_projector (a common fine-tuning setup):

def freeze_all_but_projector(llava):
    # Freeze everything, then re-enable gradients for the multimodal projector
    for p in llava.parameters():
        p.requires_grad = False
    for p in llava.mm_projector.parameters():
        p.requires_grad = True

freeze_all_but_projector(model.llava)
optimizer = torch.optim.AdamW(
    list(model.mask_head_deocder.parameters()) + list(model.llava.mm_projector.parameters()),
    lr=1e-4
)
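
If you want your own head and the projector to use different learning rates, a sketch with optimizer parameter groups (the learning-rate values here are illustrative, not tuned):

# Separate parameter groups let the freshly initialized head and the
# pretrained projector be updated at different learning rates.
optimizer = torch.optim.AdamW([
    {"params": model.mask_head_deocder.parameters(), "lr": 1e-4},
    {"params": model.llava.mm_projector.parameters(), "lr": 2e-5},
])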

Train only the last N layers of the language model (N=2 as an example):

def freeze_all_but_last_n_transformer_blocks(llava, n=2):
    for p in llava.parameters():
        p.requires_grad = False
    # In LLaVA the LLM's transformer blocks usually live at llava.model.layers
    for blk in llava.model.layers[-n:]:
        for p in blk.parameters():
            p.requires_grad = True

freeze_all_but_last_n_transformer_blocks(model.llava, n=2)
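
Remember to also hand the unfrozen blocks to the optimizer. A sketch, assuming the same attribute layout as the snippet above (llava.model.layers):

# Gather the parameters of the last two LLM blocks plus the mask head
# and train only those.
last_block_params = [p for blk in model.llava.model.layers[-2:] for p in blk.parameters()]
optimizer = torch.optim.AdamW(
    list(model.mask_head_deocder.parameters()) + last_block_params,
    lr=1e-4
)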

Train only the vision tower (uncommon):

def freeze_all_but_vision(llava):
    for p in llava.parameters():
        p.requires_grad = False
    for p in llava.vision_tower.parameters():
        p.requires_grad = True

  3. Quick sanity checks after freezing
def count_trainable_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print("trainable in XVQAModel:", count_trainable_params(model))
print("trainable in LLaVA:", count_trainable_params(model.llava))
print("trainable in mask head:", count_trainable_params(model.mask_head_deocder))