class GIoULoss(nn.Module):
    """Generalized IoU (GIoU) loss between paired axis-aligned boxes.

    Box coordinates live along dimension 1 in the order
    ``(x1, y1, x2, y2)`` (left, top, right, bottom); every other dimension
    indexes individual boxes. Widths and heights use the inclusive pixel
    convention (``x2 - x1 + 1``), so a box with ``x1 == x2`` is one pixel
    wide.

    The loss is ``1 - GIoU`` averaged over all box pairs, where
    ``GIoU = IoU - (C - U) / C``, ``U`` is the union area and ``C`` is the
    area of the smallest box enclosing both boxes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, A, B):
        """Return the mean GIoU loss between boxes ``A`` and ``B``.

        Args:
            A: Predicted boxes with coordinates on dim 1, e.g. shape
                ``(N, 4)`` or ``(batch, 4, num_boxes)``.
            B: Target boxes laid out the same way as ``A``.

        Returns:
            Scalar tensor ``sum(1 - GIoU) / number_of_boxes``. When there are
            no boxes the result is 0 rather than NaN.

        Note:
            Boxes are assumed valid (``x2 >= x1`` and ``y2 >= y1``).
            Degenerate boxes can make the union or enclosing area
            non-positive.
        """
        ax, ay, ar, ab = A[:, 0], A[:, 1], A[:, 2], A[:, 3]
        bx, by, br, bb = B[:, 0], B[:, 1], B[:, 2], B[:, 3]

        # Count the boxes actually summed below. For (batch, 4, n) input this
        # equals batch * n; it also works for (N, 4) and higher-rank layouts.
        num_bbox = ax.numel()

        # Intersection rectangle; clamp so disjoint boxes give zero area.
        inter_w = (torch.min(ar, br) - torch.max(ax, bx) + 1).clamp(min=0)
        inter_h = (torch.min(ab, bb) - torch.max(ay, by) + 1).clamp(min=0)
        inter = inter_w * inter_h

        area_a = (ar - ax + 1) * (ab - ay + 1)
        area_b = (br - bx + 1) * (bb - by + 1)
        union = area_a + area_b - inter
        iou = inter / union

        # Smallest axis-aligned box enclosing both A and B.
        enclose_w = torch.max(ar, br) - torch.min(ax, bx) + 1
        enclose_h = torch.max(ab, bb) - torch.min(ay, by) + 1
        enclose = enclose_w * enclose_h

        giou = iou - (enclose - union) / enclose
        total = (1 - giou).sum()
        # Avoid 0/0 = NaN on an empty batch; the sum is already 0 there.
        return total / max(num_bbox, 1)