Python 装饰器(@)示例:weighted_loss

def weighted_loss(loss_func):
    """Create a weighted version of a given loss function.

    To use this decorator, the loss function must have a signature like
    `loss_func(pred, target, **kwargs)`. The function only needs to compute
    the element-wise loss without any reduction. This decorator adds weight
    and reduction arguments to the function. The decorated function has a
    signature like `loss_func(pred, target, weight=None, reduction='mean',
    avg_factor=None, **kwargs)`.

    Args:
        loss_func: Callable invoked as `loss_func(pred, target, **kwargs)`
            that returns the unreduced loss.

    Returns:
        A wrapper that first calls `loss_func` and then hands the result,
        together with `weight`, `reduction` and `avg_factor`, to
        `weight_reduce_loss`. The wrapper keeps the name and docstring of
        `loss_func` (via `functools.wraps`).

    Note:
        `weight_reduce_loss` is defined outside this block; the exact
        weighting and reduction semantics (and the set of accepted
        `reduction` values) are determined there.

    :Example:

    >>> import torch
    >>> @weighted_loss
    ... def l1_loss(pred, target):
    ...     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3])
    >>> target = torch.Tensor([1, 1, 1])
    >>> weight = torch.Tensor([1, 0, 1])

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.)
    >>> l1_loss(pred, target, reduction='none')
    tensor([1., 1., 2.])
    >>> l1_loss(pred, target, weight, avg_factor=2)
    tensor(1.5000)
    """

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # Compute the raw loss first; extra keyword arguments are forwarded
        # untouched to the wrapped function.
        loss = loss_func(pred, target, **kwargs)
        # Weighting and reduction are delegated to the shared helper so that
        # every decorated loss handles these arguments the same way.
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
posted @ 2022-05-12 09:08  lisong333  阅读(26)  评论(0)    收藏  举报