Update the demo.py

This commit is contained in:
Zhongyu Jiang
2024-11-25 12:59:23 -08:00
parent 4e965684ba
commit 65f5767f09
2 changed files with 7 additions and 2 deletions

3
.gitignore vendored
View File

@@ -161,3 +161,6 @@ visualization/*
# .DS_Store # .DS_Store
.DS_Store .DS_Store
# For Testing
demo/

View File

@@ -45,13 +45,15 @@ def main(args):
frames_or_path = prepare_frames_or_path(args.video_path) frames_or_path = prepare_frames_or_path(args.video_path)
prompts = load_txt(args.txt_path) prompts = load_txt(args.txt_path)
frame_rate = 30
if args.save_to_video: if args.save_to_video:
if osp.isdir(args.video_path): if osp.isdir(args.video_path):
frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")]) frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith((".jpg", ".jpeg", ".JPG", ".JPEG"))])
loaded_frames = [cv2.imread(frame_path) for frame_path in frames] loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
height, width = loaded_frames[0].shape[:2] height, width = loaded_frames[0].shape[:2]
else: else:
cap = cv2.VideoCapture(args.video_path) cap = cv2.VideoCapture(args.video_path)
frame_rate = cap.get(cv2.CAP_PROP_FPS)
loaded_frames = [] loaded_frames = []
while True: while True:
ret, frame = cap.read() ret, frame = cap.read()
@@ -65,7 +67,7 @@ def main(args):
raise ValueError("No frames were loaded from the video.") raise ValueError("No frames were loaded from the video.")
fourcc = cv2.VideoWriter_fourcc(*'mp4v') fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height)) out = cv2.VideoWriter(args.video_output_path, fourcc, frame_rate, (width, height))
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
state = predictor.init_state(frames_or_path, offload_video_to_cpu=True) state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)