[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
@@ -6,11 +6,15 @@
|
||||
|
||||
|
||||
import copy
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sam2.utils.misc import mask_to_box
|
||||
|
||||
|
||||
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
|
||||
"""
|
||||
@@ -147,3 +151,173 @@ class LayerNorm2d(nn.Module):
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
||||
|
||||
|
||||
def sample_box_points(
|
||||
masks: torch.Tensor,
|
||||
noise: float = 0.1, # SAM default
|
||||
noise_bound: int = 20, # SAM default
|
||||
top_left_label: int = 2,
|
||||
bottom_right_label: int = 3,
|
||||
) -> Tuple[np.array, np.array]:
|
||||
"""
|
||||
Sample a noised version of the top left and bottom right corners of a given `bbox`
|
||||
|
||||
Inputs:
|
||||
- masks: [B, 1, H,W] boxes, dtype=torch.Tensor
|
||||
- noise: noise as a fraction of box width and height, dtype=float
|
||||
- noise_bound: maximum amount of noise (in pure pixesl), dtype=int
|
||||
|
||||
Returns:
|
||||
- box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
|
||||
- box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32
|
||||
"""
|
||||
device = masks.device
|
||||
box_coords = mask_to_box(masks)
|
||||
B, _, H, W = masks.shape
|
||||
box_labels = torch.tensor(
|
||||
[top_left_label, bottom_right_label], dtype=torch.int, device=device
|
||||
).repeat(B)
|
||||
if noise > 0.0:
|
||||
if not isinstance(noise_bound, torch.Tensor):
|
||||
noise_bound = torch.tensor(noise_bound, device=device)
|
||||
bbox_w = box_coords[..., 2] - box_coords[..., 0]
|
||||
bbox_h = box_coords[..., 3] - box_coords[..., 1]
|
||||
max_dx = torch.min(bbox_w * noise, noise_bound)
|
||||
max_dy = torch.min(bbox_h * noise, noise_bound)
|
||||
box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
|
||||
box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
|
||||
|
||||
box_coords = box_coords + box_noise
|
||||
img_bounds = (
|
||||
torch.tensor([W, H, W, H], device=device) - 1
|
||||
) # uncentered pixel coords
|
||||
box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
|
||||
|
||||
box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
|
||||
box_labels = box_labels.reshape(-1, 2)
|
||||
return box_coords, box_labels
|
||||
|
||||
|
||||
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
|
||||
"""
|
||||
Sample `num_pt` random points (along with their labels) independently from the error regions.
|
||||
|
||||
Inputs:
|
||||
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
|
||||
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
|
||||
- num_pt: int, number of points to sample independently for each of the B error maps
|
||||
|
||||
Outputs:
|
||||
- points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
|
||||
- labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
|
||||
negative clicks
|
||||
"""
|
||||
if pred_masks is None: # if pred_masks is not provided, treat it as empty
|
||||
pred_masks = torch.zeros_like(gt_masks)
|
||||
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
|
||||
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
|
||||
assert num_pt >= 0
|
||||
|
||||
B, _, H_im, W_im = gt_masks.shape
|
||||
device = gt_masks.device
|
||||
|
||||
# false positive region, a new point sampled in this region should have
|
||||
# negative label to correct the FP error
|
||||
fp_masks = ~gt_masks & pred_masks
|
||||
# false negative region, a new point sampled in this region should have
|
||||
# positive label to correct the FN error
|
||||
fn_masks = gt_masks & ~pred_masks
|
||||
# whether the prediction completely match the ground-truth on each mask
|
||||
all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
|
||||
all_correct = all_correct[..., None, None]
|
||||
|
||||
# channel 0 is FP map, while channel 1 is FN map
|
||||
pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
|
||||
# sample a negative new click from FP region or a positive new click
|
||||
# from FN region, depend on where the maximum falls,
|
||||
# and in case the predictions are all correct (no FP or FN), we just
|
||||
# sample a negative click from the background region
|
||||
pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
|
||||
pts_noise[..., 1] *= fn_masks
|
||||
pts_idx = pts_noise.flatten(2).argmax(dim=2)
|
||||
labels = (pts_idx % 2).to(torch.int32)
|
||||
pts_idx = pts_idx // 2
|
||||
pts_x = pts_idx % W_im
|
||||
pts_y = pts_idx // W_im
|
||||
points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
|
||||
return points, labels
|
||||
|
||||
|
||||
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
|
||||
"""
|
||||
Sample 1 random point (along with its label) from the center of each error region,
|
||||
that is, the point with the largest distance to the boundary of each error region.
|
||||
This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
|
||||
|
||||
Inputs:
|
||||
- gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
|
||||
- pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
|
||||
- padding: if True, pad with boundary of 1 px for distance transform
|
||||
|
||||
Outputs:
|
||||
- points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
|
||||
- labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
|
||||
"""
|
||||
import cv2
|
||||
|
||||
if pred_masks is None:
|
||||
pred_masks = torch.zeros_like(gt_masks)
|
||||
assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
|
||||
assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
|
||||
|
||||
B, _, _, W_im = gt_masks.shape
|
||||
device = gt_masks.device
|
||||
|
||||
# false positive region, a new point sampled in this region should have
|
||||
# negative label to correct the FP error
|
||||
fp_masks = ~gt_masks & pred_masks
|
||||
# false negative region, a new point sampled in this region should have
|
||||
# positive label to correct the FN error
|
||||
fn_masks = gt_masks & ~pred_masks
|
||||
|
||||
fp_masks = fp_masks.cpu().numpy()
|
||||
fn_masks = fn_masks.cpu().numpy()
|
||||
points = torch.zeros(B, 1, 2, dtype=torch.float)
|
||||
labels = torch.ones(B, 1, dtype=torch.int32)
|
||||
for b in range(B):
|
||||
fn_mask = fn_masks[b, 0]
|
||||
fp_mask = fp_masks[b, 0]
|
||||
if padding:
|
||||
fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
|
||||
fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
|
||||
# compute the distance of each point in FN/FP region to its boundary
|
||||
fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
|
||||
fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
|
||||
if padding:
|
||||
fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
|
||||
fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
|
||||
|
||||
# take the point in FN/FP region with the largest distance to its boundary
|
||||
fn_mask_dt_flat = fn_mask_dt.reshape(-1)
|
||||
fp_mask_dt_flat = fp_mask_dt.reshape(-1)
|
||||
fn_argmax = np.argmax(fn_mask_dt_flat)
|
||||
fp_argmax = np.argmax(fp_mask_dt_flat)
|
||||
is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
|
||||
pt_idx = fn_argmax if is_positive else fp_argmax
|
||||
points[b, 0, 0] = pt_idx % W_im # x
|
||||
points[b, 0, 1] = pt_idx // W_im # y
|
||||
labels[b, 0] = int(is_positive)
|
||||
|
||||
points = points.to(device)
|
||||
labels = labels.to(device)
|
||||
return points, labels
|
||||
|
||||
|
||||
def get_next_point(gt_masks, pred_masks, method):
|
||||
if method == "uniform":
|
||||
return sample_random_points_from_errors(gt_masks, pred_masks)
|
||||
elif method == "center":
|
||||
return sample_one_point_from_error_center(gt_masks, pred_masks)
|
||||
else:
|
||||
raise ValueError(f"unknown sampling method {method}")
|
||||
|
Reference in New Issue
Block a user