NTU ML2023Spring Part3.15 meta learning
要实现 metasolver 里的内层和外层 update.
def MetaSolver(
    model,
    optimizer,
    x,
    n_way,
    k_shot,
    q_query,
    loss_fn,
    inner_train_step=1,
    inner_lr=0.4,
    train=True,
    return_labels=False
):
    """Run one meta-batch of MAML-style inner adaptation and outer update.

    For every task in ``x`` the first ``n_way * k_shot`` entries are used as
    the support set and the remaining entries as the query set. The model
    parameters are adapted on the support set for ``inner_train_step``
    gradient steps (producing "fast weights"), and the adapted weights are
    then evaluated on the query set.

    This function relies on module-level names defined elsewhere in the
    file: ``device``, ``create_label`` and ``calculate_accuracy``.

    Args:
        model: Module exposing ``named_parameters()`` and
            ``functional_forward(inputs, weights)``.
        optimizer: Optimizer over the model parameters; used for the outer
            update.
        x: Iterable of tasks; each task must support slicing along its
            first dimension.
        n_way: Number of classes per task.
        k_shot: Number of support examples per class.
        q_query: Number of query examples per class.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar
            loss tensor.
        inner_train_step: Number of inner-loop gradient steps per task.
        inner_lr: Step size of the inner-loop update.
        train: If true, back-propagate the mean query loss and step the
            optimizer.
        return_labels: If true, skip loss computation and return the
            predicted query labels instead.

    Returns:
        If ``return_labels`` is true, a list with the predicted label of
        every query example over all tasks. Otherwise a tuple
        ``(meta_batch_loss, task_acc)`` holding the mean query loss (a
        tensor) and the mean of the per-task accuracies.
    """
    criterion, task_loss, task_acc = loss_fn, [], []
    labels = []

    # Differentiating through the inner update (second-order MAML) is only
    # useful when the outer backward pass will actually run. For validation
    # and testing the forward values are identical without it, and skipping
    # it avoids building and retaining the second-order graph.
    needs_meta_grad = bool(train) and not return_labels

    for meta_batch in x:
        # Split the task into support (adaptation) and query (evaluation).
        support_set = meta_batch[: n_way * k_shot]
        query_set = meta_batch[n_way * k_shot:]

        # Start the inner loop from the current model parameters.
        fast_weights = OrderedDict(model.named_parameters())

        # ---------- INNER TRAIN LOOP ----------
        for inner_step in range(inner_train_step):
            train_label = create_label(n_way, k_shot).to(device)
            logits = model.functional_forward(support_set, fast_weights)
            loss = criterion(logits, train_label)

            # Inner loop update: one SGD step on the fast weights.
            # NOTE: with create_graph=True the inner gradients stay
            # differentiable, i.e. this is full (second-order) MAML. A
            # first-order variant (FO-MAML) would use create_graph=False
            # even while training.
            grads = torch.autograd.grad(
                loss, fast_weights.values(), create_graph=needs_meta_grad
            )
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), grads)
            )

        # ---------- INNER VALID LOOP ----------
        if not return_labels:
            # Training / validation: collect the query loss for the outer
            # loop together with the query accuracy.
            val_label = create_label(n_way, q_query).to(device)
            logits = model.functional_forward(query_set, fast_weights)
            loss = criterion(logits, val_label)
            task_loss.append(loss)
            task_acc.append(calculate_accuracy(logits, val_label))
        else:
            # Testing: only the predicted classes are needed.
            logits = model.functional_forward(query_set, fast_weights)
            labels.extend(torch.argmax(logits, -1).cpu().numpy())

    if return_labels:
        return labels

    # ---------- OUTER LOOP ----------
    model.train()
    optimizer.zero_grad()
    meta_batch_loss = torch.stack(task_loss).mean()
    if train:
        # Outer loop update: back-propagate the mean query loss through the
        # inner updates and step the meta-optimizer.
        meta_batch_loss.backward()
        optimizer.step()

    task_acc = np.mean(task_acc)
    return meta_batch_loss, task_acc
内层使用机器学习来写机器学习(感谢 Colab 提供的 Gemini),外层就直接无脑 backward、step 即可.
跑得飞快,但交上去只有 0.62,甚至没过 simple baseline.
原来是 step5 里 solver = 'base' 忘改为 solver = 'meta' 了.改完重新跑了一发,得到 0.81 的成绩.过了 medium,但没过 strong.

浙公网安备 33010602011771号