pifpaf的loss生成过程

pifpaf的loss是通过loss = network.losses.Factory().factory(datamodule.head_metas生成的。
trainer的loop循环,然后train函数,然后进入train_batch这里求了loss。

factory中

if self.auto_tune_mtl:
loss = MultiHeadLossAutoTuneKendall(
losses, component_lambdas, sparse_task_parameters=sparse_task_parameters)
elif self.auto_tune_mtl_variance:
loss = MultiHeadLossAutoTuneVariance(
losses, component_lambdas, sparse_task_parameters=sparse_task_parameters)
else:
loss = MultiHeadLoss(losses, component_lambdas)

return loss

loss由这三个类生成。

在这三个类中生成损失函数值的过程。

losses是多个损失函数类。

flat_head_losses = [ll
for l, f, t in zip(self.losses, head_fields, head_targets)
for ll in l(f, t)]

posted @ 2023-03-07 11:01  祥瑞哈哈哈  阅读(26)  评论(0)    收藏  举报