NTU ML2023Spring Part3.13 network compression

Added random rotation to the data augmentation:

train_tfm = transforms.Compose([
    transforms.RandomRotation(degrees=(-30, 30)),  # rotate by a random angle in [-30, 30] degrees
    # ...
])

Added a few layers of depthwise and pointwise convolution (a rough parameter comparison against a plain convolution follows the code below):

# Define your student network here. You have to copy-paste this code block to HW13 GradeScope before deadline.
# We will use your student network definition to evaluate your results (including the total parameter amount).

# Example implementation of Depthwise and Pointwise Convolution 
def dwpw_conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, groups=in_channels), #depthwise convolution
        nn.Conv2d(in_channels, out_channels, 1), # pointwise convolution
    )

class StudentNet(nn.Module):
    def __init__(self):
      super().__init__()

      # ---------- TODO ----------
      # Modify your model architecture

      self.cnn = nn.Sequential(
        nn.Conv2d(3, 4, 3), 
        nn.BatchNorm2d(4),
        nn.ReLU(),    

        nn.Conv2d(4, 16, 3), 
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        
        dwpw_conv(16, 64, 3), 
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),

        dwpw_conv(64, 64, 3), 
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 64, 3), 
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        dwpw_conv(64, 64, 3), 
        nn.BatchNorm2d(64),
        nn.ReLU(),
        
        dwpw_conv(64, 84, 3), 
        nn.BatchNorm2d(84),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        
        # Here we adopt Global Average Pooling for various input size.
        nn.AdaptiveAvgPool2d((1, 1)),
      )
      self.fc = nn.Sequential(
        nn.Linear(84, 11),
      )
      
    def forward(self, x):
      out = self.cnn(x)
      out = out.view(out.size()[0], -1)
      return self.fc(out)
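
For a rough sense of the savings, compare a plain 3x3 convolution with the separable version (reusing the dwpw_conv helper above, with 64 -> 64 channels as an example):

standard = nn.Conv2d(64, 64, 3)   # 64*64*3*3 + 64 = 36928 parameters
separable = dwpw_conv(64, 64, 3)  # 640 (depthwise) + 4160 (pointwise) = 4800 parameters

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard), count(separable))  # 36928 4800

The separable block needs roughly an eighth of the parameters of the plain convolution.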

The parameter count works out to 56235, just under the 60000 cap.
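
A quick way to double-check that number is to instantiate the model and sum the parameter counts:

model = StudentNet()
print(sum(p.numel() for p in model.parameters()))  # 56235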

Implemented the KL divergence loss (which turned out to be wrong):

def loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=1.0):
    # ------------TODO-------------
    # Refer to the above formula and finish the loss function for knowledge distillation using KL divergence loss and CE loss.
    # If you have no idea, please take a look at the provided useful link above.
    p = (student_logits / temperature).softmax(dim=0)  # bug: softmax over the batch dim instead of the class dim
    q = (teacher_logits / temperature).softmax(dim=0)
    return (alpha * temperature ** 2 * torch.nn.KLDivLoss(reduction="batchmean")(p, q)    # bug: KLDivLoss expects log-probabilities as its first argument
            + (1 - alpha) * torch.nn.CrossEntropyLoss()(student_logits, teacher_logits))  # bug: the CE target should be labels, not teacher_logits

Ran 50 epochs on pure faith. The score was only 0.33. How could that be?

Turns out the KL divergence was written wrong. So, time to do some actual machine learning:


def loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=1.0):
    # Soften both distributions with the temperature, over the class dimension.
    p = (student_logits / temperature).softmax(dim=1)
    q = (teacher_logits / temperature).softmax(dim=1)
    # KLDivLoss takes log-probabilities (student) and probabilities (teacher).
    kl_loss = nn.KLDivLoss(reduction="batchmean")(p.log(), q)
    # Ordinary cross entropy against the hard labels.
    ce_loss = nn.CrossEntropyLoss()(student_logits, labels)
    return alpha * temperature ** 2 * kl_loss + (1 - alpha) * ce_loss
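
This is the standard soft-target distillation loss: alpha * T^2 times the KL divergence between the temperature-softened teacher and student distributions, plus (1 - alpha) times ordinary cross entropy on the hard labels. A sketch of how it slots into a training step (teacher_model, student_model, optimizer, train_loader and device stand in for the corresponding objects in the homework code):

teacher_model.eval()  # the teacher stays frozen during distillation
for imgs, labels in train_loader:
    imgs, labels = imgs.to(device), labels.to(device)
    with torch.no_grad():
        teacher_logits = teacher_model(imgs)
    student_logits = student_model(imgs)
    loss = loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=1.0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()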

I also tweaked the network architecture while I was at it:

class StudentNet(nn.Module):
    def __init__(self):
      super().__init__()

      # ---------- TODO ----------
      # Modify your model architecture

      self.cnn = nn.Sequential(
        nn.Conv2d(3, 4, 3), 
        nn.BatchNorm2d(4),
        nn.ReLU(),    

        nn.Conv2d(4, 16, 3), 
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        
        dwpw_conv(16, 64, 3), 
        nn.BatchNorm2d(64),
        nn.ReLU(),

        dwpw_conv(64, 32, 3), 
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        
        nn.Conv2d(32, 32, 3), 
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
          
        dwpw_conv(32, 128, 3), 
        nn.BatchNorm2d(128),
        nn.ReLU(),

        dwpw_conv(128, 128, 3), 
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        
        dwpw_conv(128, 84, 3), 
        nn.BatchNorm2d(84),
        nn.ReLU(),
        
        # Here we adopt Global Average Pooling for various input size.
        nn.AdaptiveAvgPool2d((1, 1)),
      )
      self.fc = nn.Sequential(
        nn.Linear(84, 11),
      )
      
    def forward(self, x):
      out = self.cnn(x)
      out = out.view(out.size()[0], -1)
      return self.fc(out)

Epoch 1 already reached 0.35 accuracy. Silly me, really.

After the full 50 epochs the score was 0.692, just squeaking past the medium baseline.
