def weighted_loss(loss_func):
    """Turn an element-wise loss into one that supports weighting and reduction.

    The decorated ``loss_func`` must accept ``loss_func(pred, target,
    **kwargs)`` and return the per-element loss with no reduction applied.
    The callable returned by this decorator accepts
    ``loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)``. It first computes the element-wise loss
    with ``loss_func``. It then passes that loss, together with ``weight``,
    ``reduction`` and ``avg_factor``, to ``weight_reduce_loss`` and returns
    the result.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def _weighted_and_reduced(pred, target, weight=None, reduction='mean',
                              avg_factor=None, **kwargs):
        # Extra keyword arguments go only to the raw loss; weighting and
        # reduction are handled entirely by weight_reduce_loss.
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction, avg_factor)

    return _weighted_and_reduced