update to the latest sam2 version and support box prompts in video tracking
This commit is contained in:
@@ -220,10 +220,24 @@ def fill_holes_in_mask_scores(mask, max_area):
|
||||
# Holes are those connected components in background with area <= self.max_area
|
||||
# (background regions are those with mask scores <= 0)
|
||||
assert max_area > 0, "max_area must be positive"
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
|
||||
input_mask = mask
|
||||
try:
|
||||
labels, areas = get_connected_components(mask <= 0)
|
||||
is_hole = (labels > 0) & (areas <= max_area)
|
||||
# We fill holes with a small positive mask score (0.1) to change them to foreground.
|
||||
mask = torch.where(is_hole, 0.1, mask)
|
||||
except Exception as e:
|
||||
# Skip the post-processing step on removing small holes if the CUDA kernel fails
|
||||
warnings.warn(
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. "
|
||||
"Consider building SAM 2 with CUDA extension to enable post-processing (see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
mask = input_mask
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user