[New Feature] Support SAM 2.1 (#59)

* support sam 2.1

* refine config path and ckpt path

* update README
This commit is contained in:
Ren Tianhe
2024-10-10 14:55:50 +08:00
committed by GitHub
parent e899ad99e8
commit 82e503604f
340 changed files with 39100 additions and 608 deletions

View File

@@ -6,11 +6,15 @@
import copy
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sam2.utils.misc import mask_to_box
def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
"""
@@ -147,3 +151,173 @@ class LayerNorm2d(nn.Module):
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
def sample_box_points(
    masks: torch.Tensor,
    noise: float = 0.1,  # SAM default
    noise_bound: int = 20,  # SAM default
    top_left_label: int = 2,
    bottom_right_label: int = 3,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a (optionally noised) top-left / bottom-right corner pair from the
    bounding box of each mask, as returned by `mask_to_box`.

    Inputs:
    - masks: [B, 1, H, W] masks, dtype=torch.Tensor
    - noise: noise as a fraction of box width and height, dtype=float;
      no noise (and no clamping) is applied when `noise <= 0`
    - noise_bound: maximum amount of noise (in pure pixels), dtype=int
      (a torch.Tensor is also accepted and used as-is)
    - top_left_label: label assigned to the top left corner, dtype=int
    - bottom_right_label: label assigned to the bottom right corner, dtype=int

    Returns:
    - box_coords: [B, 2, 2] torch.Tensor, contains (x, y) coordinates of the top
      left and bottom right box corners. Floating point once noise is added;
      with `noise <= 0` the dtype is whatever `mask_to_box` returns.
    - box_labels: [B, 2] torch.Tensor, dtype=torch.int32; by default label 2 is
      reserved for top left and 3 for bottom right corners
    """
    device = masks.device
    # NOTE(review): the indexing below treats the last dim of `mask_to_box`'s
    # output as (x_min, y_min, x_max, y_max) -- confirm against its definition.
    box_coords = mask_to_box(masks)
    B, _, H, W = masks.shape
    box_labels = torch.tensor(
        [top_left_label, bottom_right_label], dtype=torch.int, device=device
    ).repeat(B)
    if noise > 0.0:
        if not isinstance(noise_bound, torch.Tensor):
            noise_bound = torch.tensor(noise_bound, device=device)
        bbox_w = box_coords[..., 2] - box_coords[..., 0]
        bbox_h = box_coords[..., 3] - box_coords[..., 1]
        # per-box jitter limit: a fraction of the box size, capped at `noise_bound`
        max_dx = torch.min(bbox_w * noise, noise_bound)
        max_dy = torch.min(bbox_h * noise, noise_bound)
        # uniform noise in [-1, 1), scaled independently for each coordinate
        box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
        box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)

        # out-of-place add: creates a fresh tensor, so the in-place clamp below
        # never touches the tensor returned by `mask_to_box`
        box_coords = box_coords + box_noise
        img_bounds = (
            torch.tensor([W, H, W, H], device=device) - 1
        )  # uncentered pixel coords
        box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds)  # In place clamping

    box_coords = box_coords.reshape(-1, 2, 2)  # always 2 points
    box_labels = box_labels.reshape(-1, 2)
    return box_coords, box_labels
def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
    """
    Draw `num_pt` independent random correction clicks per mask from the pixels
    where the prediction disagrees with the ground truth.

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool, or None
      (None is treated as an all-False prediction)
    - num_pt: int, number of points sampled independently for each of the B masks

    Outputs:
    - points: [B, num_pt, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, num_pt], dtype=torch.int32, 1 for a positive click (taken from the
      false negative region) and 0 for a negative click (taken from the false
      positive region)
    """
    if pred_masks is None:
        # no prediction given: behave as if nothing was predicted
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
    assert num_pt >= 0

    batch, _, height, width = gt_masks.shape

    # predicted-but-wrong pixels: a click here must be negative
    false_pos = ~gt_masks & pred_masks
    # missed pixels: a click here must be positive
    false_neg = gt_masks & ~pred_masks
    # per-mask flag telling whether the prediction equals the ground truth everywhere
    perfect = (gt_masks == pred_masks).flatten(2).all(dim=2)[..., None, None]
    # when nothing is wrong, fall back to a negative click on the background
    negative_candidates = false_pos | (perfect & ~gt_masks)

    # last axis: slot 0 scores negative-click candidates, slot 1 positive ones
    scores = torch.rand(batch, num_pt, height, width, 2, device=gt_masks.device)
    scores[..., 0] *= negative_candidates
    scores[..., 1] *= false_neg

    # the arg-max over (pixel, slot) picks the click location and its label at once
    winner = scores.flatten(2).argmax(dim=2)
    labels = (winner % 2).to(torch.int32)
    pixel = winner // 2
    xs = pixel % width
    ys = pixel // width
    points = torch.stack([xs, ys], dim=2).to(torch.float)
    return points, labels
def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
    """
    Pick one correction click per mask at the "center" of the error regions, i.e.
    the error pixel that lies farthest from the boundary of its region.
    This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py

    Inputs:
    - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
    - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool, or None
      (None is treated as an all-False prediction)
    - padding: if True, surround each error map with a 1 px border of zeros before
      the distance transform and crop the border off afterwards

    Outputs:
    - points: [B, 1, 2], dtype=torch.float, (x, y) coordinates of each sampled point
    - labels: [B, 1], dtype=torch.int32, 1 for a positive click and 0 for a negative click
    """
    import cv2

    if pred_masks is None:
        pred_masks = torch.zeros_like(gt_masks)
    assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
    assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape

    batch, _, _, width = gt_masks.shape
    device = gt_masks.device

    def _flat_boundary_distance(region):
        # L2 distance of every pixel of `region` to the region's boundary, flattened
        if padding:
            region = np.pad(region, ((1, 1), (1, 1)), "constant")
        dist = cv2.distanceTransform(region.astype(np.uint8), cv2.DIST_L2, 0)
        if padding:
            dist = dist[1:-1, 1:-1]
        return dist.reshape(-1)

    # predicted-but-wrong pixels call for a negative click, missed pixels for a
    # positive one; the distance transform runs on CPU numpy arrays
    false_pos = (~gt_masks & pred_masks).cpu().numpy()
    false_neg = (gt_masks & ~pred_masks).cpu().numpy()

    points = torch.zeros(batch, 1, 2, dtype=torch.float)
    labels = torch.ones(batch, 1, dtype=torch.int32)
    for b in range(batch):
        fn_dist = _flat_boundary_distance(false_neg[b, 0])
        fp_dist = _flat_boundary_distance(false_pos[b, 0])

        # deepest pixel of each error region
        fn_best = np.argmax(fn_dist)
        fp_best = np.argmax(fp_dist)

        # click on whichever region is "thicker"; ties go to the false positive side
        is_positive = fn_dist[fn_best] > fp_dist[fp_best]
        best = fn_best if is_positive else fp_best

        points[b, 0, 0] = best % width  # x
        points[b, 0, 1] = best // width  # y
        labels[b, 0] = int(is_positive)

    return points.to(device), labels.to(device)
def get_next_point(gt_masks, pred_masks, method):
    """
    Sample the next correction click for each mask with the requested strategy.

    Inputs:
    - gt_masks: ground-truth masks, forwarded unchanged to the sampler
    - pred_masks: predicted masks (or None), forwarded unchanged to the sampler
    - method: "uniform" dispatches to `sample_random_points_from_errors`,
      "center" dispatches to `sample_one_point_from_error_center`

    Outputs:
    - the (points, labels) pair returned by the selected sampler

    Raises:
    - ValueError: if `method` is neither "uniform" nor "center"
    """
    if method == "uniform":
        return sample_random_points_from_errors(gt_masks, pred_masks)
    if method == "center":
        return sample_one_point_from_error_center(gt_masks, pred_masks)
    raise ValueError(f"unknown sampling method {method}")