Swin-Transformer Source Code Study and Usage Manual

After reading the ViT, TNT, and Swin-Transformer papers, together with analyses by Bilibili creators,

I have a feeling that Swin-Transformer will be a ResNet-level milestone,

so I studied the Swin-Transformer source code and its usage, and recorded my notes below.

1 Run 

Swin-T
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path data --batch-size 128

Swin-S
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main.py --cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path data --batch-size 128

Swin-B
python -m torch.distributed.launch --nproc_per_node 2 --master_port 12345 main.py --cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path data --batch-size 64 --accumulation-steps 2

In the commands above:

--nproc_per_node specifies the number of GPUs,
--master_port specifies the port number,
--cfg specifies the YAML config file, whose parameters matter most,

--data-path is the dataset root; the dataset subfolders must be named train and val,

--batch-size is the batch size.
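
For reference, the expected layout under --data-path (an illustrative sketch; the class folder names are placeholders):

data/
├── train/
│   ├── class_a/xxx.jpg
│   └── class_b/yyy.jpg
└── val/
    ├── class_a/zzz.jpg
    └── class_b/www.jpg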


 

2 config.py

 

_C.MODEL.RESUME is an important parameter: it is the path to the pretrained model.
The pretrained weights can be downloaded from the project's GitHub page, https://github.com/microsoft/Swin-Transformer.

_C.MODEL.NUM_CLASSES is the number of classification classes.
_C.MODEL.TYPE = 'swin' is the model type; 'swin' and 'swinv2' are available here.
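
For reference, a minimal sketch of how these fields are defined in config.py (the repo uses a yacs CfgNode; the default values shown here are illustrative):

from yacs.config import CfgNode as CN

_C = CN()
_C.MODEL = CN()
# Model type ('swin', 'swinv2', ...)
_C.MODEL.TYPE = 'swin'
# Number of classification classes
_C.MODEL.NUM_CLASSES = 1000
# Path to the checkpoint / pretrained weights to resume from
_C.MODEL.RESUME = ''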




3 Model

The Swin Transformer architecture figure (not reproduced here) shows that it keeps ViT's patch idea and splits the image into patches; note that the patch embedding is implemented with a convolution, as in the sketch of the source implementation below.
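
A minimal sketch of the patch embedding, following the official PatchEmbed module (abridged; shape comments assume a 224×224 input with patch_size 4 and embed_dim 96):

import torch.nn as nn

class PatchEmbed(nn.Module):
    """Image-to-patch embedding via a strided convolution (abridged sketch)."""
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        # kernel_size = stride = patch_size: each non-overlapping patch is projected to embed_dim
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                   # B, embed_dim, H/4, W/4
        x = x.flatten(2).transpose(1, 2)   # B, num_patches, embed_dim
        return x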

Swin Transformer also uses two kinds of attention blocks, W-MSA and SW-MSA. Compared with W-MSA, the S in SW-MSA stands for Shifted, i.e. the shifted-window mechanism.
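
To make the two block types concrete, a minimal sketch of the underlying operations, following the official implementation (the standalone snippet below is illustrative):

import torch

def window_partition(x, window_size):
    """Cut a (B, H, W, C) feature map into (num_windows*B, window_size, window_size, C) windows."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

x = torch.randn(1, 56, 56, 96)                          # stage-1 feature map of Swin-T
w_msa_windows = window_partition(x, 7)                  # (64, 7, 7, 96): W-MSA attends within each window
shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))   # SW-MSA first applies a cyclic shift of window_size // 2
sw_msa_windows = window_partition(shifted, 7)           # then partitions the shifted map into windows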


4 Transfer Learning and Improvement Strategies

4.1 Transfer Learning

Transformer-based models rely on large datasets, but large datasets are hard to collect in practice, so transfer learning is used to address this. For Swin, Microsoft has open-sourced weights pretrained on ImageNet,

and transfer learning on top of these pretrained weights is done by modifying the source code. The overall procedure is as follows:

(1) Set the pretrained-weight path in config.py.

(2) In utils.py, locate the load_checkpoint function, which loads the pretrained weights, and add the code below. It essentially rebuilds the final classification head:

first check whether the head's output dimension matches the number of classes; if not, create a new randomly initialized head whose input feature length is 768 (the Swin-T feature dimension).

# Added inside load_checkpoint() in utils.py, before the state dict is loaded into the model
if checkpoint['model']['head.weight'].shape[0] != config.MODEL.NUM_CLASSES:
    # Rebuild the classification head when the pretrained head does not match the
    # target number of classes; 768 is the Swin-T output feature dimension.
    checkpoint['model']['head.weight'] = torch.nn.Parameter(
        torch.nn.init.xavier_uniform_(torch.empty(config.MODEL.NUM_CLASSES, 768)))
    checkpoint['model']['head.bias'] = torch.nn.Parameter(torch.randn(config.MODEL.NUM_CLASSES))

(3) Run main.py to start training the transferred model.

4.2 Pooling Improvement: Introducing SoftPool

SoftPool is introduced by adapting its official source code. The original code is:
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _single
# CUDA_SOFTPOOL1d is the compiled CUDA extension shipped with the official SoftPool repo.

def soft_pool1d(x, kernel_size=2, stride=None, force_inplace=False):
    if x.is_cuda and not force_inplace:
        x = CUDA_SOFTPOOL1d.apply(x, kernel_size, stride)
        # Replace NaNs if found
        if torch.isnan(x).any():
            return torch.nan_to_num(x)
        return x
    kernel_size = _single(kernel_size)
    if stride is None:
        stride = kernel_size
    else:
        stride = _single(stride)
    # Get input sizes
    _, c, d = x.size()
    # Create exponential mask (should be similar to max-like pooling)
    e_x = torch.sum(torch.exp(x), dim=1, keepdim=True)
    e_x = torch.clamp(e_x, float(0), float('inf'))
    # Apply mask to input, pool, and normalize by the pooled exponential sum
    # Tensor: [b x c x d] -> [b x c x d']
    x = F.avg_pool1d(x.mul(e_x), kernel_size, stride=stride).mul_(sum(kernel_size)).div_(
        F.avg_pool1d(e_x, kernel_size, stride=stride).mul_(sum(kernel_size)))
    return torch.clamp(x, float(0), float('inf'))

The code above is the original SoftPool implementation. However, Swin originally uses adaptive average pooling whose output length is reduced to 1, so this code cannot be dropped into Swin directly. It is modified as follows:


class SoftPool1d(torch.nn.Module):
    def __init__(self):
        super(SoftPool1d, self).__init__()
        self.kernel_size = 2

    def forward(self, input_tensor):
        _, c, d = input_tensor.size()
        # Create exponential mask (should be similar to max-like pooling)
        e_x = torch.sum(torch.exp(input_tensor), dim=1, keepdim=True)
        e_x = torch.clamp(e_x, float(0), float('inf'))
        # Weight the input by the mask, then use adaptive average pooling so the
        # output length is 1, matching Swin's original nn.AdaptiveAvgPool1d(1)
        # Tensor: [b x c x d] -> [b x c x 1]
        kernel_size = _single(self.kernel_size)
        x = F.adaptive_avg_pool1d(input_tensor.mul(e_x), output_size=1).mul_(sum(kernel_size)).div_(
            F.adaptive_avg_pool1d(e_x, output_size=1).mul_(sum(kernel_size)))
        return torch.clamp(x, float(0), float('inf'))
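
A quick shape check of the adapted module (a minimal sketch; 49 tokens of dimension 768 correspond to Swin-T's final stage at 224×224 resolution):

import torch

pool = SoftPool1d()
x = torch.randn(2, 49, 768)          # B, L, C as produced by Swin's last stage
out = pool(x.transpose(1, 2))        # pool over the token dimension
print(out.shape)                     # torch.Size([2, 768, 1])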
 

Also modify the command-line args to add a pool argument that selects pool_type; when building the pooling layer, branch on it as shown below.

        # In SwinTransformer.__init__: build the pooling layer according to pool_type
        if pool_type == 'avgpool':
            self.avgpool = nn.AdaptiveAvgPool1d(1)
        elif pool_type == 'softpool':
            self.softpool = SoftPool1d()

        # In the forward pass: apply the selected pooling, then flatten
        if self.pool_type == 'avgpool':
            x = self.avgpool(x.transpose(1, 2))  # B C 1
        elif self.pool_type == 'softpool':
            x = self.softpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)

4.3 Swin Improvement Experiment Log

| No. | Swin pretrained weights | Swin frozen | ConvNeXt pretrained weights | ConvNeXt frozen | Swin pool_type | Feature fusion | acc1_val | acc1_test | save_path |
| 4.3.1 | ImageNet-22k | √ | ImageNet | √ | avg_pool | cat | 91.7 | 87.667 | swinConNext_swin_frozen |
| 4.3.2 | ImageNet-22k | × | ImageNet | √ | avg_pool | cat+softmax | 93.85 | 89.000 | swinConNext_swin_softmax |
| 4.3.3 | ImageNet-22k | √ | ImageNet | √ | avg_pool | cat+softmax | 91.1538 |  | swinConNext_swin_frozen_softmax |

 

4.3.1

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path data --batch-size 256 --output swinConNext_swin_frozen --resume pth/swin_tiny_patch4_window7_224_22k.pth --swin_frozen True

 

Notes: ConvNeXt uses pretrained weights, is frozen, and serves as a feature extractor; its output vector is flattened and concatenated (cat) with Swin's. Swin uses ImageNet-22k pretrained weights and is frozen here as well (--swin_frozen True).

 

import torch
import torch.nn as nn
# convnext_base is assumed here to come from torchvision.models
# (it could equally be the official ConvNeXt repo's model factory).
from torchvision.models import convnext_base


class ConvNext(nn.Module):
    """Frozen ConvNeXt-Base feature extractor (everything up to and including avgpool)."""
    def __init__(self):
        super(ConvNext, self).__init__()
        self.convnext = self.get_model()

    def forward(self, x):
        x = self.convnext(x)
        return x

    def get_model(self):
        model = convnext_base(pretrained=True)
        layer_name = 'avgpool'
        previous_model = nn.Sequential()
        # Copy children up to and including the avgpool layer, dropping the classifier
        for name, module in model.named_children():
            if name == layer_name:
                previous_model.add_module(name, module)
                break
            previous_model.add_module(name, module)
        # Freeze all ConvNeXt weights so it acts purely as a feature extractor
        for param in previous_model.parameters():
            param.requires_grad = False
        return previous_model


class SwinConvNext(nn.Module):
    """Fuse Swin and ConvNeXt features by concatenation (optionally softmax-weighted)."""
    def __init__(self, swin, swin_frozen=False, softmax_weights=False):
        super(SwinConvNext, self).__init__()
        self.swin = self.get_swin(swin, swin_frozen)
        self.convnext = ConvNext()
        # 1792 = 768 (Swin-T features) + 1024 (ConvNeXt-Base features); 3 output classes
        self.fc = nn.Linear(1792, 3)
        self.softmax_weights = softmax_weights

    def forward(self, x):
        swin_output = self.swin(x.clone())
        convnext_output = self.convnext(x)
        swin_output = torch.flatten(swin_output, 1)
        convnext_output = torch.flatten(convnext_output, 1)
        x = torch.cat((swin_output, convnext_output), dim=1)
        if self.softmax_weights:
            # cat+softmax fusion: re-weight the concatenated features by their softmax
            weights = torch.softmax(x, dim=1)
            x = x * weights
        x = self.fc(x)
        return x

    def get_swin(self, swin, swin_frozen):
        if swin_frozen:
            for param in swin.parameters():
                param.requires_grad = False
        return swin
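
A minimal usage sketch (assumptions: build_swin_backbone is a hypothetical helper standing in for the repo's model construction, and the Swin head is replaced with nn.Identity() so the backbone emits 768-d features rather than class logits):

import torch
import torch.nn as nn

# `swin` is a SwinTransformer built as usual from the repo's model builder;
# replacing its head exposes the 768-d pooled features instead of logits.
swin = build_swin_backbone()     # hypothetical helper, not part of the repo
swin.head = nn.Identity()

model = SwinConvNext(swin, swin_frozen=True, softmax_weights=False)
logits = model(torch.randn(2, 3, 224, 224))   # -> shape [2, 3]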

Evaluation command

 

python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py --eval --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --resume swinConNext_swin_frozen/swin_tiny_patch4_window7_224/default/ckpt_epoch_13.pth --data-path ../CUMT-BelT

4.3.2

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path data --batch-size 128 --output swinConNext_swin_softmax --resume pth/swin_tiny_patch4_window7_224_22k.pth --softmax_weights True

Evaluation command

python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py --eval --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --resume swinConNext_swin_softmax/swin_tiny_patch4_window7_224/default/ckpt_epoch_12.pth --data-path ../CUMT-BelT --softmax_weights True

 

4.3.3

nohup python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path data --batch-size 256 --output swinConNext_swin_frozen_softmax  --resume pth/swin_tiny_patch4_window7_224_22k.pth --softmax_weights True --swin_frozen True > test.log &

 

Evaluation command

python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py --eval --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --resume swinConNext_swin_frozen_softmax/swin_tiny_patch4_window7_224/default/ckpt_epoch_99.pth --data-path ../CUMT-BelT --softmax_weights True
