[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
@@ -177,6 +177,47 @@ def load_video_frames(
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Load the video frames from video_path. The frames are resized to image_size as in
|
||||
the model and are loaded to GPU if offload_video_to_cpu=False. This is used by the demo.
|
||||
"""
|
||||
is_bytes = isinstance(video_path, bytes)
|
||||
is_str = isinstance(video_path, str)
|
||||
is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
|
||||
if is_bytes or is_mp4_path:
|
||||
return load_video_frames_from_video_file(
|
||||
video_path=video_path,
|
||||
image_size=image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
img_mean=img_mean,
|
||||
img_std=img_std,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
elif is_str and os.path.isdir(video_path):
|
||||
return load_video_frames_from_jpg_images(
|
||||
video_path=video_path,
|
||||
image_size=image_size,
|
||||
offload_video_to_cpu=offload_video_to_cpu,
|
||||
img_mean=img_mean,
|
||||
img_std=img_std,
|
||||
async_loading_frames=async_loading_frames,
|
||||
compute_device=compute_device,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Only MP4 video and JPEG folder are supported at this moment"
|
||||
)
|
||||
|
||||
|
||||
def load_video_frames_from_jpg_images(
|
||||
video_path,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
async_loading_frames=False,
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""
|
||||
Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
|
||||
@@ -236,6 +277,38 @@ def load_video_frames(
|
||||
return images, video_height, video_width
|
||||
|
||||
|
||||
def load_video_frames_from_video_file(
|
||||
video_path,
|
||||
image_size,
|
||||
offload_video_to_cpu,
|
||||
img_mean=(0.485, 0.456, 0.406),
|
||||
img_std=(0.229, 0.224, 0.225),
|
||||
compute_device=torch.device("cuda"),
|
||||
):
|
||||
"""Load the video frames from a video file."""
|
||||
import decord
|
||||
|
||||
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
|
||||
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
|
||||
# Get the original video height and width
|
||||
decord.bridge.set_bridge("torch")
|
||||
video_height, video_width, _ = decord.VideoReader(video_path).next().shape
|
||||
# Iterate over all frames in the video
|
||||
images = []
|
||||
for frame in decord.VideoReader(video_path, width=image_size, height=image_size):
|
||||
images.append(frame.permute(2, 0, 1))
|
||||
|
||||
images = torch.stack(images, dim=0).float() / 255.0
|
||||
if not offload_video_to_cpu:
|
||||
images = images.to(compute_device)
|
||||
img_mean = img_mean.to(compute_device)
|
||||
img_std = img_std.to(compute_device)
|
||||
# normalize by mean and std
|
||||
images -= img_mean
|
||||
images /= img_std
|
||||
return images, video_height, video_width
|
||||
|
||||
|
||||
def fill_holes_in_mask_scores(mask, max_area):
|
||||
"""
|
||||
A post processor to fill small holes in mask scores with area under `max_area`.
|
||||
@@ -256,7 +329,7 @@ def fill_holes_in_mask_scores(mask, max_area):
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
"https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
@@ -108,7 +108,7 @@ class SAM2Transforms(nn.Module):
|
||||
f"{e}\n\nSkipping the post-processing step due to the error above. You can "
|
||||
"still use SAM 2 and it's OK to ignore the error above, although some post-processing "
|
||||
"functionality may be limited (which doesn't affect the results in most cases; see "
|
||||
"https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
|
||||
"https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
Reference in New Issue
Block a user