123 lines
4.8 KiB
Python
123 lines
4.8 KiB
Python
import argparse
|
|
import os
|
|
import os.path as osp
|
|
import numpy as np
|
|
import cv2
|
|
import torch
|
|
import gc
|
|
import sys
|
|
sys.path.append("./sam2")
|
|
from sam2.build_sam import build_sam2_video_predictor
|
|
|
|
# Colour palette used when drawing masks and boxes; entries are picked by
# object id modulo the palette length. OpenCV images are BGR, so
# (255, 0, 0) renders as blue.
color = [
    (255, 0, 0),
]
|
|
|
|
def load_txt(gt_path):
    """Parse a ground-truth box file into per-frame tracking prompts.

    Each non-blank line holds one box as ``x, y, w, h`` (top-left corner plus
    width and height). Values may be separated by commas and/or whitespace and
    may be written as floats; each is truncated to ``int``.

    Args:
        gt_path: Path to the text file.

    Returns:
        dict mapping the 0-based line index to ``((x1, y1, x2, y2), 0)``: the
        box converted to corner form, paired with the label ``0``.

    Raises:
        ValueError: if a non-blank line does not hold exactly four numbers.
    """
    prompts = {}
    # utf-8-sig transparently drops a leading BOM that would otherwise break float().
    with open(gt_path, "r", encoding="utf-8-sig") as f:
        for fid, line in enumerate(f):
            # Accept "x,y,w,h" as well as tab/space separated boxes.
            fields = line.replace(",", " ").split()
            if not fields:
                # Blank lines (typically a trailing newline at EOF) previously
                # crashed float(''); skip them but keep line-based frame ids.
                continue
            if len(fields) != 4:
                raise ValueError(
                    f"{gt_path}:{fid + 1}: expected 4 box values (x, y, w, h), "
                    f"got {len(fields)}"
                )
            x, y, w, h = (int(float(v)) for v in fields)
            prompts[fid] = ((x, y, x + w, y + h), 0)
    return prompts
|
|
|
|
def determine_model_cfg(model_path):
    """Map a SAM 2.1 checkpoint path to the matching SAMURAI config file.

    The model size is inferred from a substring ("large", "base_plus",
    "small", "tiny", checked in that order). The checkpoint's file name is
    examined first so that directory names cannot override it (previously
    ``small_runs/sam2.1_hiera_tiny.pt`` resolved to the *small* config); the
    full path is consulted only when the file name carries no size.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        Relative path of the config for that model size.

    Raises:
        ValueError: if no known size substring appears in the path.
    """
    size_to_cfg = (
        ("large", "configs/samurai/sam2.1_hiera_l.yaml"),
        ("base_plus", "configs/samurai/sam2.1_hiera_b+.yaml"),
        ("small", "configs/samurai/sam2.1_hiera_s.yaml"),
        ("tiny", "configs/samurai/sam2.1_hiera_t.yaml"),
    )
    for candidate in (osp.basename(model_path), model_path):
        for size, cfg in size_to_cfg:
            if size in candidate:
                return cfg
    raise ValueError("Unknown model size in path!")
|
|
|
|
def prepare_frames_or_path(video_path):
    """Validate the video input and return it unchanged.

    Only paths ending in ``.mp4`` and existing directories (expected to hold
    JPEG frames) are accepted; the directory contents are not inspected here.

    Args:
        video_path: Path to an ``.mp4`` file or to a frame directory.

    Returns:
        ``video_path`` itself.

    Raises:
        ValueError: if ``video_path`` neither ends in ``.mp4`` nor is a directory.
    """
    looks_valid = video_path.endswith(".mp4") or osp.isdir(video_path)
    if not looks_valid:
        raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")
    return video_path
|
|
|
|
# Frame rate used for JPEG-directory inputs (which carry no timing) and as a
# fallback when a video container reports no usable FPS.
DEFAULT_FRAME_RATE = 30

# File extensions treated as frames when video_path is a directory.
_FRAME_EXTENSIONS = (".jpg", ".jpeg", ".JPG", ".JPEG")


def _frame_sort_key(filename):
    """Sort key ordering frame files numerically by stem, else lexicographically.

    NOTE(review): SAM2's JPEG-folder loader orders frames by ``int(stem)``;
    mirroring that keeps ``frame_idx`` from ``propagate_in_video`` aligned with
    the frames drawn here (a plain lexicographic sort puts "10.jpg" before
    "2.jpg"). Confirm against sam2/utils/misc.py if that loader changes.
    """
    stem = osp.splitext(filename)[0]
    if stem.isdecimal():
        return (0, int(stem), filename)
    return (1, 0, filename)


def _load_frames(video_path):
    """Read every frame of ``video_path`` (video file or JPEG directory) as BGR images.

    Returns:
        ``(frames, frame_rate)``: list of images and the FPS to write with.

    Raises:
        ValueError: if a frame image cannot be read or no frames are found.
    """
    frame_rate = DEFAULT_FRAME_RATE
    frames = []
    if osp.isdir(video_path):
        names = sorted(
            (name for name in os.listdir(video_path) if name.endswith(_FRAME_EXTENSIONS)),
            key=_frame_sort_key,
        )
        for name in names:
            frame_path = osp.join(video_path, name)
            image = cv2.imread(frame_path)
            if image is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise ValueError(f"Could not read frame image: {frame_path}")
            frames.append(image)
    else:
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # 0 (or NaN) means the container did not report an FPS.
            if fps > 0:
                frame_rate = fps
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()
    # Checked before anything indexes frames[0].
    if not frames:
        raise ValueError("No frames were loaded from the video.")
    return frames, frame_rate


def _mask_to_bbox(mask):
    """Return the tight ``[x, y, w, h]`` box around True pixels of a 2-D mask, or None if empty."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        return None
    x_min, x_max = int(xs.min()), int(xs.max())
    y_min, y_max = int(ys.min()), int(ys.max())
    return [x_min, y_min, x_max - x_min, y_max - y_min]


def main(args):
    """Track one object through a video with SAMURAI and optionally render it.

    The box for frame 0 (``prompts[0]`` from ``args.txt_path``) initialises
    the tracker, which is then propagated through the whole video. When
    ``args.save_to_video`` is truthy, every frame is written to
    ``args.video_output_path`` with each predicted mask blended in and its
    bounding box drawn.

    Args:
        args: Namespace with ``video_path``, ``txt_path``, ``model_path``,
            ``video_output_path`` and ``save_to_video`` attributes.

    Raises:
        ValueError: for an unknown model size, an unsupported ``video_path``,
            a box file without a frame-0 box, or a video with no readable frames.
        OSError: if the output video writer cannot be opened.
    """
    # Cheap input validation first so bad arguments fail before the model is built.
    model_cfg = determine_model_cfg(args.model_path)
    frames_or_path = prepare_frames_or_path(args.video_path)
    prompts = load_txt(args.txt_path)
    if 0 not in prompts:
        raise ValueError(f"No bounding box for frame 0 in {args.txt_path}")

    predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")

    loaded_frames = None
    out = None
    if args.save_to_video:
        loaded_frames, frame_rate = _load_frames(args.video_path)
        height, width = loaded_frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(args.video_output_path, fourcc, frame_rate, (width, height))
        if not out.isOpened():
            raise OSError(f"Could not open video writer for {args.video_output_path}")

    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
            state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
            init_bbox, _track_label = prompts[0]
            predictor.add_new_points_or_box(state, box=init_bbox, frame_idx=0, obj_id=0)

            for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
                if out is None:
                    # Masks are only consumed for rendering; keep iterating so
                    # propagation still runs over the whole video as before.
                    continue

                # Binarise each object's logits; [0] drops a leading dimension
                # (NOTE(review): assumed (1, H, W) per object — confirm).
                binary_masks = {
                    obj_id: mask[0].cpu().numpy() > 0.0
                    for obj_id, mask in zip(object_ids, masks)
                }

                img = loaded_frames[frame_idx]
                # Blend all masks first, then draw boxes on top, so a later
                # blend never tints an earlier outline. Mask and box of one
                # object share a palette entry (they used to be offset by one).
                for obj_id, mask in binary_masks.items():
                    mask_img = np.zeros((height, width, 3), np.uint8)
                    mask_img[mask] = color[obj_id % len(color)]
                    img = cv2.addWeighted(img, 1, mask_img, 0.2, 0)

                for obj_id, mask in binary_masks.items():
                    bbox = _mask_to_bbox(mask)
                    if bbox is None:
                        # Empty mask: previously drew a stray box at (0, 0).
                        continue
                    x, y, w, h = bbox
                    cv2.rectangle(img, (x, y), (x + w, y + h), color[obj_id % len(color)], 2)

                out.write(img)
    finally:
        # Finalise the output file even if tracking raised midway.
        if out is not None:
            out.release()

    # Drop predictor/state references and ask PyTorch to return cached memory.
    del predictor, state
    gc.collect()
    torch.clear_autocast_cache()
    torch.cuda.empty_cache()
|
|
|
|
def parse_bool(value):
    """Convert a command-line flag value to ``bool``.

    Without a ``type``, argparse keeps the raw string, so
    ``--save_to_video False`` used to yield the truthy string "False" and
    still write a video.

    Args:
        value: The raw string from the command line (or an existing bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--video_path", required=True, help="Input video path or directory of frames.")
    parser.add_argument("--txt_path", required=True, help="Path to ground truth text file.")
    parser.add_argument("--model_path", default="sam2/checkpoints/sam2.1_hiera_base_plus.pt", help="Path to the model checkpoint.")
    parser.add_argument("--video_output_path", default="demo.mp4", help="Path to save the output video.")
    parser.add_argument(
        "--save_to_video",
        type=parse_bool,
        nargs="?",
        const=True,  # a bare "--save_to_video" means True
        default=True,
        help="Save results to a video (true/false).",
    )
    args = parser.parse_args()
    main(args)
|