update visualization func

This commit is contained in:
rentainhe
2024-08-09 19:14:20 +08:00
parent ccacb31e59
commit cabbad473b
2 changed files with 88 additions and 5 deletions

View File

@@ -29,7 +29,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device",device)
print("device", device)
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
@@ -189,10 +189,9 @@ for start_frame_idx in range(0, len(frame_names), step):
json.dump(json_data, f)
"""
Step 6: Draw the results and save the video
"""
CommonUtils.draw_masks_and_box(video_dir, mask_data_dir, json_data_dir, result_dir)
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
create_video_from_images(result_dir, output_video_path, frame_rate=30)