init commit of samurai
This commit is contained in:
63
lib/utils/focal_loss.py
Normal file
63
lib/utils/focal_loss.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from abc import ABC
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FocalLoss(nn.Module, ABC):
|
||||
def __init__(self, alpha=2, beta=4):
|
||||
super(FocalLoss, self).__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
|
||||
def forward(self, prediction, target):
|
||||
positive_index = target.eq(1).float()
|
||||
negative_index = target.lt(1).float()
|
||||
|
||||
negative_weights = torch.pow(1 - target, self.beta)
|
||||
# clamp min value is set to 1e-12 to maintain the numerical stability
|
||||
prediction = torch.clamp(prediction, 1e-12)
|
||||
|
||||
positive_loss = torch.log(prediction) * torch.pow(1 - prediction, self.alpha) * positive_index
|
||||
negative_loss = torch.log(1 - prediction) * torch.pow(prediction,
|
||||
self.alpha) * negative_weights * negative_index
|
||||
|
||||
num_positive = positive_index.float().sum()
|
||||
positive_loss = positive_loss.sum()
|
||||
negative_loss = negative_loss.sum()
|
||||
|
||||
if num_positive == 0:
|
||||
loss = -negative_loss
|
||||
else:
|
||||
loss = -(positive_loss + negative_loss) / num_positive
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class LBHinge(nn.Module):
|
||||
"""Loss that uses a 'hinge' on the lower bound.
|
||||
This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is
|
||||
also smaller than that threshold.
|
||||
args:
|
||||
error_matric: What base loss to use (MSE by default).
|
||||
threshold: Threshold to use for the hinge.
|
||||
clip: Clip the loss if it is above this value.
|
||||
"""
|
||||
def __init__(self, error_metric=nn.MSELoss(), threshold=None, clip=None):
|
||||
super().__init__()
|
||||
self.error_metric = error_metric
|
||||
self.threshold = threshold if threshold is not None else -100
|
||||
self.clip = clip
|
||||
|
||||
def forward(self, prediction, label, target_bb=None):
|
||||
negative_mask = (label < self.threshold).float()
|
||||
positive_mask = (1.0 - negative_mask)
|
||||
|
||||
prediction = negative_mask * F.relu(prediction) + positive_mask * prediction
|
||||
|
||||
loss = self.error_metric(prediction, positive_mask * label)
|
||||
|
||||
if self.clip is not None:
|
||||
loss = torch.min(loss, torch.tensor([self.clip], device=loss.device))
|
||||
return loss
|
Reference in New Issue
Block a user