support gsam2 image predictor model
This commit is contained in:
5
sam2/utils/__init__.py
Normal file
5
sam2/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
348
sam2/utils/amg.py
Normal file
348
sam2/utils/amg.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
|
||||
|
||||
|
||||
class MaskData:
|
||||
"""
|
||||
A structure for storing masks and their related data in batched format.
|
||||
Implements basic filtering and concatenation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
for v in kwargs.values():
|
||||
assert isinstance(
|
||||
v, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats = dict(**kwargs)
|
||||
|
||||
def __setitem__(self, key: str, item: Any) -> None:
|
||||
assert isinstance(
|
||||
item, (list, np.ndarray, torch.Tensor)
|
||||
), "MaskData only supports list, numpy arrays, and torch tensors."
|
||||
self._stats[key] = item
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
del self._stats[key]
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._stats[key]
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
return self._stats.items()
|
||||
|
||||
def filter(self, keep: torch.Tensor) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if v is None:
|
||||
self._stats[k] = None
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = v[keep.detach().cpu().numpy()]
|
||||
elif isinstance(v, list) and keep.dtype == torch.bool:
|
||||
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = [v[i] for i in keep]
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def cat(self, new_stats: "MaskData") -> None:
|
||||
for k, v in new_stats.items():
|
||||
if k not in self._stats or self._stats[k] is None:
|
||||
self._stats[k] = deepcopy(v)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
|
||||
elif isinstance(v, np.ndarray):
|
||||
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
|
||||
elif isinstance(v, list):
|
||||
self._stats[k] = self._stats[k] + deepcopy(v)
|
||||
else:
|
||||
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
for k, v in self._stats.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self._stats[k] = v.float().detach().cpu().numpy()
|
||||
|
||||
|
||||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
|
||||
box_xywh = deepcopy(box_xyxy)
|
||||
box_xywh[2] = box_xywh[2] - box_xywh[0]
|
||||
box_xywh[3] = box_xywh[3] - box_xywh[1]
|
||||
return box_xywh
|
||||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
assert len(args) > 0 and all(
|
||||
len(a) == len(args[0]) for a in args
|
||||
), "Batched iteration must have inputs of all the same size."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Encodes masks to an uncompressed RLE, in the format expected by
|
||||
pycoco tools.
|
||||
"""
|
||||
# Put in fortran order and flatten h,w
|
||||
b, h, w = tensor.shape
|
||||
tensor = tensor.permute(0, 2, 1).flatten(1)
|
||||
|
||||
# Compute change indices
|
||||
diff = tensor[:, 1:] ^ tensor[:, :-1]
|
||||
change_indices = diff.nonzero()
|
||||
|
||||
# Encode run length
|
||||
out = []
|
||||
for i in range(b):
|
||||
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
|
||||
cur_idxs = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
||||
cur_idxs + 1,
|
||||
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
||||
]
|
||||
)
|
||||
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
||||
counts = [] if tensor[i, 0] == 0 else [0]
|
||||
counts.extend(btw_idxs.detach().cpu().tolist())
|
||||
out.append({"size": [h, w], "counts": counts})
|
||||
return out
|
||||
|
||||
|
||||
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
||||
"""Compute a binary mask from an uncompressed RLE."""
|
||||
h, w = rle["size"]
|
||||
mask = np.empty(h * w, dtype=bool)
|
||||
idx = 0
|
||||
parity = False
|
||||
for count in rle["counts"]:
|
||||
mask[idx : idx + count] = parity
|
||||
idx += count
|
||||
parity ^= True
|
||||
mask = mask.reshape(w, h)
|
||||
return mask.transpose() # Put in C order
|
||||
|
||||
|
||||
def area_from_rle(rle: Dict[str, Any]) -> int:
|
||||
return sum(rle["counts"][1::2])
|
||||
|
||||
|
||||
def calculate_stability_score(
|
||||
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the stability score for a batch of masks. The stability
|
||||
score is the IoU between the binary masks obtained by thresholding
|
||||
the predicted mask logits at high and low values.
|
||||
"""
|
||||
# One mask is always contained inside the other.
|
||||
# Save memory by preventing unnecessary cast to torch.int64
|
||||
intersections = (
|
||||
(masks > (mask_threshold + threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
unions = (
|
||||
(masks > (mask_threshold - threshold_offset))
|
||||
.sum(-1, dtype=torch.int16)
|
||||
.sum(-1, dtype=torch.int32)
|
||||
)
|
||||
return intersections / unions
|
||||
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
return points
|
||||
|
||||
|
||||
def build_all_layer_point_grids(
|
||||
n_per_side: int, n_layers: int, scale_per_layer: int
|
||||
) -> List[np.ndarray]:
|
||||
"""Generates point grids for all crop layers."""
|
||||
points_by_layer = []
|
||||
for i in range(n_layers + 1):
|
||||
n_points = int(n_per_side / (scale_per_layer**i))
|
||||
points_by_layer.append(build_point_grid(n_points))
|
||||
return points_by_layer
|
||||
|
||||
|
||||
def generate_crop_boxes(
|
||||
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes. Each layer
|
||||
has (2**i)**2 boxes for the ith layer.
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_w, im_h])
|
||||
layer_idxs.append(0)
|
||||
|
||||
def crop_len(orig_len, n_crops, overlap):
|
||||
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
|
||||
|
||||
for i_layer in range(n_layers):
|
||||
n_crops_per_side = 2 ** (i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_w = crop_len(im_w, n_crops_per_side, overlap)
|
||||
crop_h = crop_len(im_h, n_crops_per_side, overlap)
|
||||
|
||||
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
|
||||
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
|
||||
|
||||
# Crops in XYWH format
|
||||
for x0, y0 in product(crop_box_x0, crop_box_y0):
|
||||
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
|
||||
|
||||
|
||||
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return boxes + offset
|
||||
|
||||
|
||||
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0]], device=points.device)
|
||||
# Check if points has a channel dimension
|
||||
if len(points.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return points + offset
|
||||
|
||||
|
||||
def uncrop_masks(
|
||||
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
|
||||
) -> torch.Tensor:
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
|
||||
pad = (x0, pad_x - x0, y0, pad_y - y0)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
|
||||
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
|
||||
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
|
||||
from pycocotools import mask as mask_utils # type: ignore
|
||||
|
||||
h, w = uncompressed_rle["size"]
|
||||
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
|
||||
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
|
||||
return rle
|
||||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
|
||||
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to CxHxW
|
||||
shape = masks.shape
|
||||
h, w = shape[-2:]
|
||||
if len(shape) > 2:
|
||||
masks = masks.flatten(0, -3)
|
||||
else:
|
||||
masks = masks.unsqueeze(0)
|
||||
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + h * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + w * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
if len(shape) > 2:
|
||||
out = out.reshape(*shape[:-2], 4)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
return out
|
238
sam2/utils/misc.py
Normal file
238
sam2/utils/misc.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_sdpa_settings():
|
||||
if torch.cuda.is_available():
|
||||
old_gpu = torch.cuda.get_device_properties(0).major < 7
|
||||
# only use Flash Attention on Ampere (8.0) or newer GPUs
|
||||
use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
|
||||
if not use_flash_attn:
|
||||
warnings.warn(
|
||||
"Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
|
||||
# available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
|
||||
pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
|
||||
if pytorch_version < (2, 2):
|
||||
warnings.warn(
|
||||
f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
|
||||
"Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
|
||||
else:
|
||||
old_gpu = True
|
||||
use_flash_attn = False
|
||||
math_kernel_on = True
|
||||
|
||||
return old_gpu, use_flash_attn, math_kernel_on
|
||||
|
||||
|
||||
def get_connected_components(mask):
|
||||
"""
|
||||
Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
|
||||
|
||||
Inputs:
|
||||
- mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
|
||||
background.
|
||||
|
||||
Outputs:
|
||||
- labels: A tensor of shape (N, 1, H, W) containing the connected component labels
|
||||
for foreground pixels and 0 for background pixels.
|
||||
- counts: A tensor of shape (N, 1, H, W) containing the area of the connected
|
||||
components for foreground pixels and 0 for background pixels.
|
||||
"""
|
||||
from sam2 import _C
|
||||
|
||||
return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
|
||||
|
||||
|
||||
def mask_to_box(masks: torch.Tensor):
|
||||
"""
|
||||
compute bounding box given an input mask
|
||||
|
||||
Inputs:
|
||||
- masks: [B, 1, H, W] boxes, dtype=torch.Tensor
|
||||
|
||||
Returns:
|
||||
- box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
|
||||
"""
|
||||
B, _, h, w = masks.shape
|
||||
device = masks.device
|
||||
xs = torch.arange(w, device=device, dtype=torch.int32)
|
||||
ys = torch.arange(h, device=device, dtype=torch.int32)
|
||||
grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
|
||||
grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
|
||||
grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
|
||||
min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
|
||||
max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
|
||||
min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
|
||||
max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
|
||||
bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
|
||||
|
||||
return bbox_coords
|
||||
|
||||
|
||||
def _load_img_as_tensor(img_path, image_size):
|
||||
img_pil = Image.open(img_path)
|
||||
img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
|
||||
if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
|
||||
img_np = img_np / 255.0
|
||||
else:
|
||||
raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
|
||||
img = torch.from_numpy(img_np).permute(2, 0, 1)
|
||||
video_width, video_height = img_pil.size # the original video size
|
||||
return img, video_height, video_width
|
||||
|
||||
|
||||
class AsyncVideoFrameLoader:
|
||||
"""
|
||||
A list of video frames to be load asynchronously without blocking session start.
|
||||
"""
|
||||
|
||||
def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
|
||||
self.img_paths = img_paths
|
||||
self.image_size = image_size
|
||||
self.offload_video_to_cpu = offload_video_to_cpu
|
||||
self.img_mean = img_mean
|
||||
self.img_std = img_std
|
||||
# items in `self._images` will be loaded asynchronously
|
||||
self.images = [None] * len(img_paths)
|
||||
# catch and raise any exceptions in the async loading thread
|
||||
self.exception = None
|
||||
# video_height and video_width be filled when loading the first image
|
||||
self.video_height = None
|
||||
self.video_width = None
|
||||
|
||||
# load the first frame to fill video_height and video_width and also
|
||||
# to cache it (since it's most likely where the user will click)
|
||||
self.__getitem__(0)
|
||||
|
||||
# load the rest of frames asynchronously without blocking the session start
|
||||
def _load_frames():
|
||||
try:
|
||||
for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
|
||||
self.__getitem__(n)
|
||||
except Exception as e:
|
||||
self.exception = e
|
||||
|
||||
self.thread = Thread(target=_load_frames, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.exception is not None:
|
||||
raise RuntimeError("Failure in frame loading thread") from self.exception
|
||||
|
||||
img = self.images[index]
|
||||
if img is not None:
|
||||
return img
|
||||
|
||||
img, video_height, video_width = _load_img_as_tensor(
|
||||
self.img_paths[index], self.image_size
|
||||
)
|
||||
self.video_height = video_height
|
||||
self.video_width = video_width
|
||||
# normalize by mean and std
|
||||
img -= self.img_mean
|
||||
img /= self.img_std
|
||||
if not self.offload_video_to_cpu:
|
||||
img = img.cuda(non_blocking=True)
|
||||
self.images[index] = img
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
|
||||
def load_video_frames(
|
||||
video_path,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
):
|
||||
"""
|
||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||
|
||||
The frames are resized to image_size x image_size and are loaded to GPU if
|
||||
`offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
|
||||
|
||||
You can load a frame asynchronously by setting `async_loading_frames` to `True`.
|
||||
"""
|
||||
if isinstance(video_path, str) and os.path.isdir(video_path):
|
||||
jpg_folder = video_path
|
||||
else:
|
||||
raise NotImplementedError("Only JPEG frames are supported at this moment")
|
||||
|
||||
frame_names = [
|
||||
p
|
||||
for p in os.listdir(jpg_folder)
|
||||
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
||||
]
|
||||
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
||||
num_frames = len(frame_names)
|
||||
if num_frames == 0:
|
||||
raise RuntimeError(f"no images found in {jpg_folder}")
|
||||
img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
|
||||
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
|
||||
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
|
||||
|
||||
if async_loading_frames:
|
||||
lazy_images = AsyncVideoFrameLoader(
|
||||
img_paths, image_size, offload_video_to_cpu, img_mean, img_std
|
||||
)
|
||||
return lazy_images, lazy_images.video_height, lazy_images.video_width
|
||||
|
||||
images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
|
||||
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
|
||||
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
|
||||
if not offload_video_to_cpu:
|
||||
images = images.cuda()
|
||||
img_mean = img_mean.cuda()
|
||||
img_std = img_std.cuda()
|
||||
# normalize by mean and std
|
||||
images -= img_mean
|
||||
images /= img_std
|
||||
return images, video_height, video_width
|
||||
|
||||
|
||||
def fill_holes_in_mask_scores(mask, max_area):
|
||||
"""
|
||||
A post processor to fill small holes in mask scores with area under `max_area`.
|
||||
"""
|
||||
# Holes are those connected components in background with area <= self.max_area
|
||||
# (background regions are those with mask scores <= 0)
|
||||
assert max_area > 0, "max_area must be positive"
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
return mask
|
||||
|
||||
|
||||
def concat_points(old_point_inputs, new_points, new_labels):
|
||||
"""Add new points and labels to previous point inputs (add at the end)."""
|
||||
if old_point_inputs is None:
|
||||
points, labels = new_points, new_labels
|
||||
else:
|
||||
points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
|
||||
labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
|
||||
|
||||
return {"point_coords": points, "point_labels": labels}
|
99
sam2/utils/transforms.py
Normal file
99
sam2/utils/transforms.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision.transforms import Normalize, Resize, ToTensor
|
||||
|
||||
|
||||
class SAM2Transforms(nn.Module):
|
||||
def __init__(
|
||||
self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
|
||||
):
|
||||
"""
|
||||
Transforms for SAM2.
|
||||
"""
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.mask_threshold = mask_threshold
|
||||
self.max_hole_area = max_hole_area
|
||||
self.max_sprinkle_area = max_sprinkle_area
|
||||
self.mean = [0.485, 0.456, 0.406]
|
||||
self.std = [0.229, 0.224, 0.225]
|
||||
self.to_tensor = ToTensor()
|
||||
self.transforms = torch.jit.script(
|
||||
nn.Sequential(
|
||||
Resize((self.resolution, self.resolution)),
|
||||
Normalize(self.mean, self.std),
|
||||
)
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.to_tensor(x)
|
||||
return self.transforms(x)
|
||||
|
||||
def forward_batch(self, img_list):
|
||||
img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
|
||||
img_batch = torch.stack(img_batch, dim=0)
|
||||
return img_batch
|
||||
|
||||
def transform_coords(
|
||||
self, coords: torch.Tensor, normalize=False, orig_hw=None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates,
|
||||
If the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
|
||||
|
||||
Returns
|
||||
Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model.
|
||||
"""
|
||||
if normalize:
|
||||
assert orig_hw is not None
|
||||
h, w = orig_hw
|
||||
coords = coords.clone()
|
||||
coords[..., 0] = coords[..., 0] / w
|
||||
coords[..., 1] = coords[..., 1] / h
|
||||
|
||||
coords = coords * self.resolution # unnormalize coords
|
||||
return coords
|
||||
|
||||
def transform_boxes(
|
||||
self, boxes: torch.Tensor, normalize=False, orig_hw=None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates,
|
||||
if the coords are in absolute image coordinates, normalize should be set to True and original image size is required.
|
||||
"""
|
||||
boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
|
||||
return boxes
|
||||
|
||||
def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
|
||||
"""
|
||||
Perform PostProcessing on output masks.
|
||||
"""
|
||||
from sam2.utils.misc import get_connected_components
|
||||
|
||||
masks = masks.float()
|
||||
if self.max_hole_area > 0:
|
||||
# Holes are those connected components in background with area <= self.fill_hole_area
|
||||
# (background regions are those with mask scores <= self.mask_threshold)
|
||||
mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
|
||||
labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_hole_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with a small positive mask score (10.0) to change them to foreground.
|
||||
masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
|
||||
|
||||
if self.max_sprinkle_area > 0:
|
||||
labels, areas = get_connected_components(mask_flat > self.mask_threshold)
|
||||
is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
|
||||
is_hole = is_hole.reshape_as(masks)
|
||||
# We fill holes with negative mask score (-10.0) to change them to background.
|
||||
masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
|
||||
|
||||
masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
|
||||
return masks
|
Reference in New Issue
Block a user