pifpaf的loss生成过程
pifpaf的loss是通过loss = network.losses.Factory().factory(datamodule.head_metas生成的。
trainer的loop循环,然后train函数,然后进入train_batch这里求了loss。
factory中
if self.auto_tune_mtl:
loss = MultiHeadLossAutoTuneKendall(
losses, component_lambdas, sparse_task_parameters=sparse_task_parameters)
elif self.auto_tune_mtl_variance:
loss = MultiHeadLossAutoTuneVariance(
losses, component_lambdas, sparse_task_parameters=sparse_task_parameters)
else:
loss = MultiHeadLoss(losses, component_lambdas)
return loss
loss由这三个类生成。
在这三个类中生成损失函数值的过程。
losses是多个损失函数类。
flat_head_losses = [ll
for l, f, t in zip(self.losses, head_fields, head_targets)
for ll in l(f, t)]

浙公网安备 33010602011771号