From 65f5767f0936446d9ce8bebf963e58e9ffb3c82e Mon Sep 17 00:00:00 2001 From: Zhongyu Jiang Date: Mon, 25 Nov 2024 12:59:23 -0800 Subject: [PATCH] Update the demo.py --- .gitignore | 3 +++ scripts/demo.py | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index e1d94cd..47c87f3 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,6 @@ visualization/* # .DS_Store .DS_Store + +# For Testing +demo/ diff --git a/scripts/demo.py b/scripts/demo.py index e2ded94..4a1c944 100644 --- a/scripts/demo.py +++ b/scripts/demo.py @@ -45,13 +45,15 @@ def main(args): frames_or_path = prepare_frames_or_path(args.video_path) prompts = load_txt(args.txt_path) + frame_rate = 30 if args.save_to_video: if osp.isdir(args.video_path): - frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")]) + frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith((".jpg", ".jpeg", ".JPG", ".JPEG"))]) loaded_frames = [cv2.imread(frame_path) for frame_path in frames] height, width = loaded_frames[0].shape[:2] else: cap = cv2.VideoCapture(args.video_path) + frame_rate = cap.get(cv2.CAP_PROP_FPS) loaded_frames = [] while True: ret, frame = cap.read() @@ -65,7 +67,7 @@ def main(args): raise ValueError("No frames were loaded from the video.") fourcc = cv2.VideoWriter_fourcc(*'mp4v') - out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height)) + out = cv2.VideoWriter(args.video_output_path, fourcc, frame_rate, (width, height)) with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)