Files
Grounded-SAM-2/lib/utils/focal_loss.py

63 lines
2.3 KiB
Python
Raw Normal View History

2024-11-19 22:12:54 -08:00
from abc import ABC
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module, ABC):
    """Focal loss for dense (heatmap-style) targets with soft negatives.

    Locations where ``target == 1`` are positives. Every location with
    ``target < 1`` is a negative whose contribution is down-weighted by
    ``(1 - target) ** beta``. Both terms are modulated by the focusing
    exponent ``alpha``:

        positive term:  log(p)     * (1 - p) ** alpha
        negative term:  log(1 - p) * p ** alpha * (1 - target) ** beta

    The summed loss is negated and normalized by the number of positives.

    Args:
        alpha: Focusing exponent applied to the prediction terms.
        beta: Exponent of the ``(1 - target)`` down-weighting for negatives.
    """

    def __init__(self, alpha=2, beta=4):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, prediction, target):
        """Compute the focal loss.

        Args:
            prediction: Tensor of predicted probabilities. Values are assumed
                to lie in [0, 1] (e.g. sigmoid outputs) -- TODO confirm with
                callers.
            target: Tensor, assumed to have the same shape as ``prediction``.
                Value 1 marks positive locations; values below 1 are negatives.

        Returns:
            Scalar tensor. If there is at least one positive, this is
            ``-(positive_loss + negative_loss) / num_positive``. Otherwise it
            is ``-negative_loss``, without normalization.
        """
        eps = 1e-12
        positive_index = target.eq(1).float()
        negative_index = target.lt(1).float()
        negative_weights = torch.pow(1 - target, self.beta)

        # Keep both log arguments away from zero for numerical stability.
        # Clamping only ``prediction`` is not enough. Where prediction == 1,
        # log(1 - prediction) is -inf, and -inf times the zero masks at
        # positive locations is NaN. The argument ``1 - prediction`` is
        # clamped rather than capping prediction at ``1 - eps``, because
        # 1 - 1e-12 rounds to exactly 1.0 in float32.
        prediction = torch.clamp(prediction, min=eps)
        log_pos = torch.log(prediction)
        log_neg = torch.log(torch.clamp(1 - prediction, min=eps))

        positive_loss = (
            log_pos * torch.pow(1 - prediction, self.alpha) * positive_index
        ).sum()
        negative_loss = (
            log_neg
            * torch.pow(prediction, self.alpha)
            * negative_weights
            * negative_index
        ).sum()

        num_positive = positive_index.sum()
        if num_positive == 0:
            # No positives: avoid dividing by zero and return the raw sum.
            return -negative_loss
        return -(positive_loss + negative_loss) / num_positive
class LBHinge(nn.Module):
    """Loss that uses a 'hinge' on the lower bound.

    Samples whose label is below ``threshold`` are treated as negatives. For
    those samples the label is replaced by 0 and the prediction is passed
    through ReLU. As a result, a negative sample with a prediction <= 0 is
    compared as 0 vs 0, while a positive prediction on a negative is still
    compared against 0. All other samples are compared against their label
    unchanged.

    Args:
        error_metric: Base loss module, called as
            ``error_metric(prediction, label)``. Defaults to a fresh
            ``nn.MSELoss()`` per instance.
        threshold: Label threshold separating negatives (label < threshold)
            from positives. Defaults to -100, so every label >= -100 counts
            as a positive.
        clip: If given, the loss is clamped to at most this value.
    """

    def __init__(self, error_metric=None, threshold=None, clip=None):
        super().__init__()
        # Build the default per instance. A default-argument module would be
        # evaluated once and shared by every LBHinge.
        self.error_metric = error_metric if error_metric is not None else nn.MSELoss()
        self.threshold = threshold if threshold is not None else -100
        self.clip = clip

    def forward(self, prediction, label, target_bb=None):
        """Compute the hinged loss.

        Args:
            prediction: Prediction tensor.
            label: Label tensor, assumed to be broadcast-compatible with
                ``prediction`` -- TODO confirm shape contract with callers.
            target_bb: Unused; accepted for interface compatibility.

        Returns:
            Whatever ``error_metric`` returns (a 0-dim tensor for the default
            ``nn.MSELoss()``), optionally clamped to ``clip``.
        """
        negative_mask = (label < self.threshold).float()
        positive_mask = 1.0 - negative_mask
        # Negatives: ReLU the prediction and use target 0.
        # Positives: pass the prediction through unchanged.
        prediction = negative_mask * F.relu(prediction) + positive_mask * prediction
        loss = self.error_metric(prediction, positive_mask * label)
        if self.clip is not None:
            # clamp keeps the shape and dtype of ``loss``. Taking torch.min
            # against a 1-element tensor would turn a 0-dim loss into shape (1,).
            loss = torch.clamp(loss, max=self.clip)
        return loss