update to the latest sam2 version and support box prompts in video tracking

This commit is contained in:
rentainhe
2024-08-08 12:03:29 +08:00
parent 96cbab92e0
commit 077064c365
8 changed files with 272 additions and 45 deletions

View File

@@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area):
# Holes are those connected components in background with area <= self.max_area
# (background regions are those with mask scores <= 0)
assert max_area > 0, "max_area must be positive"
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
input_mask = mask
try:
labels, areas = get_connected_components(mask <= 0)
is_hole = (labels > 0) & (areas <= max_area)
# We fill holes with a small positive mask score (0.1) to change them to foreground.
mask = torch.where(is_hole, 0.1, mask)
except Exception as e:
# Skip the post-processing step on removing small holes if the CUDA kernel fails
warnings.warn(
f"{e}\n\nSkipping the post-processing step due to the error above. "
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
category=UserWarning,
stacklevel=2,
)
mask = input_mask
return mask