From 1d018ceb559f5f83ea11691b0d5edb643d19cfe9 Mon Sep 17 00:00:00 2001
From: rentainhe <596106517@qq.com>
Date: Fri, 2 Aug 2024 17:06:32 +0800
Subject: [PATCH] add tracking demo and support video dump

---
 grounded_sam2_tracking_demo.py | 26 +++++++++++--------------
 video_utils.py                 | 35 ++++++++++++++++++++++++++++++++++
 2 files changed, 46 insertions(+), 15 deletions(-)
 create mode 100644 video_utils.py

diff --git a/grounded_sam2_tracking_demo.py b/grounded_sam2_tracking_demo.py
index c3cc813..94edcdc 100644
--- a/grounded_sam2_tracking_demo.py
+++ b/grounded_sam2_tracking_demo.py
@@ -8,6 +8,7 @@ from sam2.build_sam import build_sam2_video_predictor, build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 from track_utils import sample_points_from_masks
+from video_utils import create_video_from_images
 
 
 """
@@ -152,21 +153,16 @@ for frame_idx, segments in video_segments.items():
         mask=masks, # (n, h, w)
         class_id=np.array(object_ids, dtype=np.int32),
     )
+    box_annotator = sv.BoxAnnotator()
+    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
     mask_annotator = sv.MaskAnnotator()
-    annotated_frame = mask_annotator.annotate(scene=img.copy(), detections=detections)
-    cv2.imwrite(f"annotated_frame_{frame_idx}.jpg", annotated_frame)
+    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+    cv2.imwrite(os.path.join(save_dir, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
 
-# import cv2
-# import supervision as sv
-# # visualize each mask
-# for out_frame_idx, masks in video_segments.items():
-#     img = cv2.imread(os.path.join(video_dir, frame_names[out_frame_idx]))
-#     detections = sv.Detections(
-#         xyxy=np.array([[0, 0, 100, 100]]), # (n, 4)
-#         mask=masks[1]
-#     )
-#     mask_annotator = sv.MaskAnnotator()
-#     annotated_frame = mask_annotator.annotate(scene=img.copy(), 
detections=detections)
-#     cv2.imwrite(f"annotated_frame_{out_frame_idx}.jpg", annotated_frame)
-#     import pdb; pdb.set_trace()
+"""
+Step 6: Convert the annotated frames to video
+"""
+
+output_video_path = "./children_tracking_demo_video.mp4"
+create_video_from_images(save_dir, output_video_path)
\ No newline at end of file
diff --git a/video_utils.py b/video_utils.py
new file mode 100644
index 0000000..bfb5f91
--- /dev/null
+++ b/video_utils.py
@@ -0,0 +1,35 @@
+import cv2
+import os
+from tqdm import tqdm
+
+def create_video_from_images(image_folder, output_video_path, frame_rate=30):
+    # Allowed image file extensions
+    valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG"]
+
+    # Collect the list of image files
+    image_files = [f for f in os.listdir(image_folder)
+                   if os.path.splitext(f)[1] in valid_extensions]
+    image_files.sort()  # Sort to ensure the images are read in the correct order
+    print(image_files)
+    if not image_files:
+        raise ValueError("No valid image files found in the specified folder.")
+
+    # Read the first image to determine the video dimensions
+    first_image_path = os.path.join(image_folder, image_files[0])
+    first_image = cv2.imread(first_image_path)
+    height, width, _ = first_image.shape
+
+    # Create the video writer object
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Other codecs, such as 'XVID', may also be used
+    video_writer = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height))
+
+    # Write the frames to the video one by one
+    for image_file in tqdm(image_files):
+        image_path = os.path.join(image_folder, image_file)
+        image = cv2.imread(image_path)
+        video_writer.write(image)
+
+    # Release resources
+    video_writer.release()
+    print(f"Video saved at {output_video_path}")
+