暑假gigapath模型蒸馏随笔1

和导师讨论暑假选题,最后在对病理模型gigapath进行模型量化(float16->float8)和模型蒸馏中选择了模型蒸馏

开始查怎么部署蒸馏脚本

一.首先关注我们的gigapath模型,他是由两个部分组成

1.tile_encoder

1)它是啥:

这个模型是一个小图片的切片编码器,数据来源是将一张大的wsi病理图像切片并进行编码,具体格式是

文件名

--文件.svs

----坐标.png(例如-1234x_5678y.png)

----坐标.png

----。。。。(几百张)

其中的坐标是图片的左上角点位于原图片的坐标

2)它是啥模型:

这个预训练的模型的架构是vit_giant_patch14_dinov2,它是一个vit模型(vision transformer),大约有4.22g

3)它做什么:

输入一个标准的图像输入(518x518)这个图像输入是wsi用gigapath项目中自带的图片切片脚本进行切片的产物

prov-gigapath/demo/2_tiling_demo.py at main · prov-gigapath/prov-gigapath

通过vision transformer将图像中的特征提取,编码成一个1536维的特征向量

在最终输出层中输出这个高维向量

4)总结

可以把这个模型看成一个病理专家,它每次只看一小块显微镜下的切片,然后把这个切片的所有能分析出来的信息迅速浓缩概括进一个高维向量里

2.silde_encoder

1)它是啥

这个模型是切片聚合器,用于聚合来自同一张 WSI 的所有 tile 特征向量,生成一个能够代表整张全切片图像的单一、最终的特征向量。

2)它是啥模型

它是一个transformer模型

add:我们可以把一张图片的上下左右位置关系,看成一个句子,离得近一点的组织块肯定和这个图片有更大的联系,就和transformer的注意力机制有很相似的原理,故采用transformer这个用于关注文本上下文的模型。

3)它做了什么

首先它需要每个小切片的位置信息,聚合这些位置信息,分析所有 tile 特征之间的关系,并且判断哪些 tiles 对于最终的诊断或分类更重要,将所有这些局部信息智能地融合成一个768维的全局切片级embending,用于下游的分析。

4)总结

可以把这个模型看成一个总指挥,它收集了整张图片所有员工对这个图片的分析,结合位置信息,对整个图片的信息进行浓缩整理总结,这个过程在学术上通常被称为多示例学习(Multiple Instance Learning, MIL)

二.了解了以上信息,选取蒸馏的学生模型

众所周知,模型蒸馏是由一个大的教师模型对学生模型进行教学,让学生模型模仿并拟合教师模型的输出,并且由于学生模型参数量小,能够很方便的在很多边缘设备进行部署运行,并且也可以很方便的对于某些特定任务进行微调处理。

对于tile_encoder:

由于我们知道这个是一个视觉的transformer模型,我们可以采取CNN或者是其他更小的vit模型来进行蒸馏,首先我们排除了CNN,优先选择300mb的vit_base来进行蒸馏,这是由于我们的数据量较小(500张左右的tcga病理图像切片),故我们采用对中间层进行学习的方法,这样可以在小批量的无标签数据中,尽可能的模仿教师模型的行为

总的loss计算公式如下

Loss_total = (lose1+lose2)/2

其中loss1是tile_encoder的loss,其表达式为

loss1 = beta * L_hidde(− beta) * L_slide

其中

L_slide: 最终幻片嵌入的loss

L_hidden:中间隐藏层的计算损失

beta为超参,即要从隐藏层中学多少东西

核心代码:

for i, (slide_tiles, slide_coords) in enumerate(progress_bar):
            if slide_tiles is None: continue
            slide_tiles = slide_tiles.to(device)
            slide_coords = slide_coords.to(device)
            
            with torch.cuda.amp.autocast():
                # --- Tile Batching Loop ---
                all_teacher_hidden_features, all_student_hidden_features = [], []
                for j in range(0, slide_tiles.size(0), config['tile_batch_size']):
                    tile_batch = slide_tiles[j:j+config['tile_batch_size']]
                    with torch.no_grad():
                        teacher_hidden_list = teacher_models['tile_backbone'](tile_batch)
                        teacher_hidden_features_batch = teacher_hidden_list[-1] # 取最后一个Block的输出
                    
                    student_hidden_list = student_models['tile_backbone'](tile_batch)
                    student_hidden_features_batch = student_hidden_list[-1] # 取最后一个Block的输出
                    
                    all_teacher_hidden_features.append(teacher_hidden_features_batch)
                    all_student_hidden_features.append(student_hidden_features_batch)
                
                teacher_hidden_features = torch.cat(all_teacher_hidden_features, dim=0)
                student_hidden_features = torch.cat(all_student_hidden_features, dim=0)
                
                # --- 核心蒸馏计算 ---
                # 1. 隐藏层损失 (L_hidden)
                #    注意:ViT的输出是 [B, N_patches+1, Dim],我们匹配所有token
                adapted_student_hidden = student_models['feature_adaptor'](student_hidden_features)
                loss_hidden = loss_fn(adapted_student_hidden, teacher_hidden_features)

                # 2. 最终Slide嵌入损失 (L_slide)
                #    首先需要通过head得到最终的tile特征
                teacher_final_tile_features = teacher_models['tile_head'](teacher_hidden_features[:, 0]) # 只取[CLS] token
                student_final_tile_features = student_models['tile_head'](student_hidden_features[:, 0])
                
                with torch.no_grad():
                    teacher_slide_embed = teacher_models['slide'](teacher_final_tile_features.unsqueeze(0), slide_coords.unsqueeze(0))
                student_slide_embed = student_models['slide'](student_final_tile_features.unsqueeze(0))
                loss_slide = loss_fn(student_slide_embed, teacher_slide_embed)

                # 3. 总损失
                total_loss = config['beta'] * loss_hidden + (1 - config['beta']) * loss_slide
                total_loss = total_loss / config['gradient_accumulation_steps']

            scaler.scale(total_loss).backward()

features_only=True:利用这个timm的强大功能,我们无需修改模型源码就能轻松获取中间层的特征图

对于decoder:

我们还是采用transformer,一个参数更小的transformer进行蒸馏

 

posted @ 2025-07-26 14:32  liujunxi  阅读(24)  评论(0)    收藏  举报