308 lines
12 KiB
Python
308 lines
12 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from collections import defaultdict
|
|
from typing import Dict, List
|
|
|
|
import torch
|
|
import torch.distributed
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from training.trainer import CORE_LOSS_KEY
|
|
|
|
from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
|
|
|
|
|
|
def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape (at least 2-D, the first
                dimension indexing objects). The predictions (logits) for each
                example; a sigmoid is applied internally.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
                Targets already flattened to [N, -1] are also accepted in the
                non-multimask case.
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        Dice loss tensor: [N, M] (per object and per predicted mask, divided by
        num_objects) when loss_on_multimask is True, otherwise a scalar (sum over
        objects divided by num_objects).
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimensions while keeping the multimask channel dimension
        start_dim = 2
    else:
        # flatten everything after the object dimension; targets must be flattened
        # too, otherwise an unflattened [N, H, W] target cannot broadcast against
        # the flattened [N, H*W] predictions
        start_dim = 1
    inputs = inputs.flatten(start_dim)
    targets = targets.flatten(start_dim)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # the +1 smoothing term avoids a division by zero
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects
|
|
|
|
|
|
def sigmoid_focal_loss(
    inputs,
    targets,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask=False,
):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape (at least 2-D, the first
                dimension indexing objects). The predictions (logits) for each
                example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. A negative value disables the
                weighting. Default = 0.25.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples. Default = 2.
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        focal loss tensor: [N, M] (averaged over spatial dims, divided by
        num_objects) when loss_on_multimask is True; otherwise a scalar obtained
        by averaging over all non-object dims, summing over objects and dividing
        by num_objects.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability of the ground-truth class
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # down-weight well-classified elements (p_t close to 1)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if loss_on_multimask:
        # loss is [N, M, H, W] where M corresponds to multiple predicted masks
        assert loss.dim() == 4
        return loss.flatten(2).mean(-1) / num_objects  # average over spatial dims
    # average over every non-object dim (not only dim 1), then sum over objects
    return loss.flatten(1).mean(-1).sum() / num_objects
|
|
|
|
|
|
def iou_loss(
    inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
):
    """
    Regression loss between predicted IoU scores and the IoU actually achieved
    by the (binarized) predicted masks against the ground truth.

    Args:
        inputs: 4-D float tensor [N, M, H, W] of mask logits.
        targets: 4-D float tensor of the same shape holding binary labels
                 (0 = background, 1 = foreground).
        pred_ious: float tensor with the predicted IoU score of each mask.
        num_objects: number of objects in the batch, used as a normalizer.
        loss_on_multimask: if True, keep the per-mask losses instead of summing.
        use_l1_loss: if True, use L1 loss; otherwise use MSE loss.
    Returns:
        IoU loss tensor (per-mask when loss_on_multimask, else a scalar), divided
        by num_objects.
    """
    assert inputs.dim() == 4 and targets.dim() == 4
    # A logit above 0 means a sigmoid probability above 0.5.
    pred_bin = inputs.flatten(2) > 0
    gt_bin = targets.flatten(2) > 0
    intersection = (pred_bin & gt_bin).sum(dim=-1).float()
    union = (pred_bin | gt_bin).sum(dim=-1).float()
    # Clamping the union to 1 gives IoU 0 (not NaN) when both masks are empty.
    measured_ious = intersection / union.clamp(min=1.0)

    regression_fn = F.l1_loss if use_l1_loss else F.mse_loss
    per_mask = regression_fn(pred_ious, measured_ious, reduction="none")
    return per_mask / num_objects if loss_on_multimask else per_mask.sum() / num_objects
|
|
|
|
|
|
class MultiStepMultiMasksAndIous(nn.Module):
    """Focal, dice and IoU-prediction losses (plus an optional object-score loss)
    accumulated over multiple prediction steps with multiple masks per step."""

    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
    ):
        """
        This class computes the multi-step multi-mask and IoU losses.
        Args:
            weight_dict: dict containing weights for focal, dice, iou losses.
                Keys "loss_mask", "loss_dice" and "loss_iou" are required;
                "loss_class" defaults to 0.0 when absent. The mapping is copied,
                so the caller's dict is never modified.
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
        """

        super().__init__()
        # Copy so that filling in the "loss_class" default below does not mutate a
        # dict the caller may share (e.g. a config object) or that is read-only.
        self.weight_dict = dict(weight_dict)
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        self.weight_dict.setdefault("loss_class", 0.0)

        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """
        Sum the losses of every (outputs, targets) pair in the batch.

        Args:
            outs_batch: list of output dicts, each consumed by ``_forward``.
            targets_batch: ground-truth masks; ``targets_batch[i]`` pairs with
                ``outs_batch[i]`` and ``targets_batch.shape[1]`` is used as the
                number of objects.

        Returns:
            dict mapping loss names (including CORE_LOSS_KEY) to values summed
            over all pairs.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = torch.tensor(
            (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
        )  # Number of objects is fixed within a batch
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_objects)
        # average the object count over ranks; clamp to at least 1 to avoid
        # dividing the losses by zero
        num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()

        losses = defaultdict(int)
        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects)
            for k, v in cur_losses.items():
                losses[k] += v

        return losses

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """
        Compute the mask losses (focal and dice) and the MAE (L1) or MSE loss
        between predicted IoUs and actual IoUs, summed over all prediction steps.

        Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
        of shape [N, M, H, W], where M could be 1 or larger, corresponding to
        one or multiple predicted masks from a click).

        We back-propagate focal, dice losses only on the prediction channel
        with the lowest weighted focal+dice loss between predicted mask and
        ground-truth. If `supervise_all_iou` is True, we backpropagate ious
        losses for all predicted masks.

        Returns:
            dict with "loss_mask", "loss_dice", "loss_iou", "loss_class" and
            CORE_LOSS_KEY (the weighted sum computed by ``reduce_loss``).
        """

        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4  # [N, 1, H, W]
        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]

        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        # accumulate the loss over prediction steps
        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
        for src_masks, ious, object_score_logits in zip(
            src_masks_list, ious_list, object_score_logits_list
        ):
            self._update_losses(
                losses, src_masks, target_masks, ious, num_objects, object_score_logits
            )
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
    ):
        """
        Add the losses of one prediction step to ``losses`` (updated in place).

        Args:
            losses: dict with keys "loss_mask", "loss_dice", "loss_iou" and
                "loss_class".
            src_masks: predicted mask logits, [N, M, H, W].
            target_masks: ground-truth masks, [N, 1, H, W]; expanded to M channels.
            ious: predicted IoU scores, one per predicted mask (presumably [N, M];
                the derived IoU loss is asserted to be 2-D).
            num_objects: normalizer every per-object loss is divided by.
            object_score_logits: object-presence logits, only used when
                ``pred_obj_scores`` is True (presumably [N, 1], like target_obj).
        """
        target_masks = target_masks.expand_as(src_masks)
        # get focal, dice and iou loss on all output masks in a prediction step
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )
        if not self.pred_obj_scores:
            # no object-score head: no class loss, and every object counts as present
            loss_class = torch.tensor(
                0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
            )
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            # an object is "present" if its ground-truth mask has any positive pixel
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
                ..., None
            ].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # take the mask indices with the smallest focal + dice loss for back propagation
            loss_combo = (
                loss_multimask * self.weight_dict["loss_mask"]
                + loss_multidice * self.weight_dict["loss_dice"]
            )
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            # calculate the iou prediction and slot losses only in the index
            # with the minimum loss for each mask (to be consistent w/ SAM)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

    def reduce_loss(self, losses):
        """
        Return the weighted sum of ``losses`` using ``self.weight_dict``.

        Terms whose weight is 0 are skipped. Raises ValueError if a weighted key
        is missing from ``losses``.
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight != 0:
                reduced_loss += losses[loss_key] * weight

        return reduced_loss
|