# scripts/demo.py — SAM2 video object tracking demo (new file, 120 lines)
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import gc
|
||||
import sys
|
||||
sys.path.append("./sam2")
|
||||
from sam2.build_sam import build_sam2_video_predictor
|
||||
|
||||
color = [(0, 0, 255)]  # single-entry overlay palette; OpenCV uses BGR order, so this is red
|
||||
|
||||
def load_txt(gt_path):
    """Parse a ground-truth file of comma-separated ``x,y,w,h`` boxes.

    Each line describes the box for one frame, in file order.

    Returns:
        dict mapping frame index -> ((x1, y1, x2, y2), 0), i.e. the box in
        corner form paired with a fixed label of 0.
    """
    prompts = {}
    with open(gt_path, 'r') as gt_file:
        for frame_id, row in enumerate(gt_file):
            # Values may be written as floats; truncate to ints like int(float(v)).
            left, top, box_w, box_h = (int(float(v)) for v in row.split(','))
            prompts[frame_id] = ((left, top, left + box_w, top + box_h), 0)
    return prompts
||||
def determine_model_cfg(model_path):
    """Return the SAM2.1 hydra config path matching the checkpoint size token
    found in *model_path*.

    Raises:
        ValueError: if no known size token appears in the path.
    """
    # Checked in this order; the first matching token wins.
    size_to_cfg = (
        ("large", "configs/sam2.1/sam2.1_hiera_l.yaml"),
        ("base_plus", "configs/sam2.1/sam2.1_hiera_b+.yaml"),
        ("small", "configs/sam2.1/sam2.1_hiera_s.yaml"),
        ("tiny", "configs/sam2.1/sam2.1_hiera_t.yaml"),
    )
    for token, cfg in size_to_cfg:
        if token in model_path:
            return cfg
    raise ValueError("Unknown model size in path!")
||||
def prepare_frames_or_path(video_path):
    """Validate *video_path* and return it unchanged.

    Accepts either an ``.mp4`` file path or an existing directory (assumed to
    hold jpg frames).

    Raises:
        ValueError: for any other input.
    """
    is_supported = video_path.endswith(".mp4") or osp.isdir(video_path)
    if not is_supported:
        raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")
    return video_path
||||
def main(args):
    """Track the first-frame ground-truth box through a video with SAM2.

    Loads the box prompt for frame 0 from ``args.txt_path``, propagates it
    through ``args.video_path`` on ``cuda:0``, and, when ``args.save_to_video``
    is truthy, writes an annotated 30-fps mp4 to ``args.video_output_path``.

    Raises:
        ValueError: on an unknown model size, an invalid video path, or an
            empty video.
    """
    model_cfg = determine_model_cfg(args.model_path)
    predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")
    frames_or_path = prepare_frames_or_path(args.video_path)
    prompts = load_txt(args.txt_path)

    if args.save_to_video:
        # Load every frame up-front so they can be annotated per tracking step.
        if osp.isdir(args.video_path):
            frames = sorted([osp.join(args.video_path, f)
                             for f in os.listdir(args.video_path) if f.endswith(".jpg")])
            loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
        else:
            cap = cv2.VideoCapture(args.video_path)
            loaded_frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                loaded_frames.append(frame)
            cap.release()

        # Bug fix: this check previously ran AFTER indexing loaded_frames[0],
        # so an empty video raised IndexError instead of this clearer error.
        if len(loaded_frames) == 0:
            raise ValueError("No frames were loaded from the video.")
        height, width = loaded_frames[0].shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height))

    # fp16 autocast keeps memory low; inference_mode disables autograd tracking.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
        bbox, _track_label = prompts[0]  # label unused: a single box prompt suffices
        # Bug fix: the returned masks previously shadowed the loop variable
        # below; the initial masks are unused, so discard them explicitly.
        predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)

        for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
            mask_to_vis = {}
            bbox_to_vis = {}

            for obj_id, mask in zip(object_ids, masks):
                mask = mask[0].cpu().numpy()
                mask = mask > 0.0
                non_zero_indices = np.argwhere(mask)
                if len(non_zero_indices) == 0:
                    bbox = [0, 0, 0, 0]  # object lost: emit a degenerate box
                else:
                    y_min, x_min = non_zero_indices.min(axis=0).tolist()
                    y_max, x_max = non_zero_indices.max(axis=0).tolist()
                    bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
                bbox_to_vis[obj_id] = bbox
                mask_to_vis[obj_id] = mask

            if args.save_to_video:
                img = loaded_frames[frame_idx]
                for obj_id, mask in mask_to_vis.items():
                    mask_img = np.zeros((height, width, 3), np.uint8)
                    # Consistency fix: was (obj_id + 1) % len(color) here but
                    # obj_id % len(color) for the rectangle below; identical for
                    # the current single-colour palette, now uniform.
                    mask_img[mask] = color[obj_id % len(color)]
                    img = cv2.addWeighted(img, 1, mask_img, 0.2, 0)

                for obj_id, bbox in bbox_to_vis.items():
                    cv2.rectangle(img, (bbox[0], bbox[1]),
                                  (bbox[0] + bbox[2], bbox[1] + bbox[3]),
                                  color[obj_id % len(color)], 2)

                out.write(img)

    if args.save_to_video:
        out.release()

    # Free GPU/host memory eagerly so back-to-back runs don't accumulate.
    del predictor, state
    gc.collect()
    torch.clear_autocast_cache()
    torch.cuda.empty_cache()
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--video_path", required=True, help="Input video path or directory of frames.")
|
||||
parser.add_argument("--txt_path", required=True, help="Path to ground truth text file.")
|
||||
parser.add_argument("--model_path", default="sam2/checkpoints/sam2.1_hiera_base_plus.pt", help="Path to the model checkpoint.")
|
||||
parser.add_argument("--video_output_path", default="demo.mp4", help="Path to save the output video.")
|
||||
parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
|
||||
args = parser.parse_args()
|
||||
main(args)
|