MAML算法概述

MAML算法概述

什么是MAML

1. 论文地址:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

2. 要解决的问题

  • 小样本问题
  • 模型收敛过慢

3. 算法描述

​ MAML期望通过训练一组初始化参数,使得模型透过训练出的初始化参数,未来在少量样本基础上实现快速收敛。该初始化参数 在训练集上未必是最优解,但可以通过训练出的参数在新的任务上快速收敛,找到最优解。

4. V.S. Pre-train

  • Pre-train:训练集上的全局最优参数,但放到测试集上未必可以训练出全局最优,可能只会找到局部最优点。

  • MAML:在训练集和测试集上未必全局最优参数,但通过少量迭代,便可收敛到全局最优。

算法描述

  1. 随机初始化一个权重θ
  2. Setp3 ~ Step10:一个epoch
  3. 随机采样一个batch的Task
  4. 遍历所有Task
  5. 从Support Set中取出一个batch的Task中的Label和Image
  6. Setp6 ~ Step7:前向传播,计算梯度后反向传播,更新θ′这个权重
  7. 从Query Set中取出所有Task前向传播,但不更新模型
  8. Step10:所有Task结束后,计算Loss,计算梯度,更新θ的权重

核心代码

for epoch in range(args.epoch//10000):
    # fetch meta_batchsz num of episode each time
    db = DataLoader(mini, args.task_num, shuffle=True, num_workers=0, pin_memory=True)

    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):

        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

        accs = maml(x_spt, y_spt, x_qry, y_qry)

        if step % 30 == 0:
            print('step:', step, '\ttraining acc:', accs)

        if step % 500 == 0:  # evaluation
            db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=0, pin_memory=True)
            accs_all_test = []

            for x_spt, y_spt, x_qry, y_qry in db_test:
                x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                             x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                accs_all_test.append(accs)

            # [b, update_step+1]
            accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
            print('Test acc:', accs)
posted @ 2023-10-18 21:38  HoroSherry  阅读(377)  评论(0)    收藏  举报