Update the demo.py
This commit is contained in:
@@ -45,13 +45,15 @@ def main(args):
|
||||
frames_or_path = prepare_frames_or_path(args.video_path)
|
||||
prompts = load_txt(args.txt_path)
|
||||
|
||||
frame_rate = 30
|
||||
if args.save_to_video:
|
||||
if osp.isdir(args.video_path):
|
||||
frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")])
|
||||
frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith((".jpg", ".jpeg", ".JPG", ".JPEG"))])
|
||||
loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
|
||||
height, width = loaded_frames[0].shape[:2]
|
||||
else:
|
||||
cap = cv2.VideoCapture(args.video_path)
|
||||
frame_rate = cap.get(cv2.CAP_PROP_FPS)
|
||||
loaded_frames = []
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
@@ -65,7 +67,7 @@ def main(args):
|
||||
raise ValueError("No frames were loaded from the video.")
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(args.video_output_path, fourcc, 30, (width, height))
|
||||
out = cv2.VideoWriter(args.video_output_path, fourcc, frame_rate, (width, height))
|
||||
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
|
||||
state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
|
||||
|
Reference in New Issue
Block a user