add dump results to hf model demo

This commit is contained in:
rentainhe
2024-08-31 20:40:59 +08:00
parent 4f3adf3222
commit a99354bb25
2 changed files with 72 additions and 15 deletions

View File

@@ -22,11 +22,12 @@ from sam2.sam2_image_predictor import SAM2ImagePredictor
Hyper parameters
"""
API_TOKEN = "Your API token"
TEXT_PROMPT = "car"
TEXT_PROMPT = "car . building ."
IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = Path("outputs/grounded_sam2_gd1.5_demo")
DUMP_JSON_RESULTS = True
@@ -79,7 +80,7 @@ Init SAM 2 Model and Predict Mask with Box Prompt
# environment settings
# use bfloat16
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
@@ -89,7 +90,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
# build SAM2 image predictor
sam2_checkpoint = SAM2_CHECKPOINT
model_cfg = SAM2_MODEL_CONFIG
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model)
image = Image.open(img_path)
@@ -160,6 +161,8 @@ if DUMP_JSON_RESULTS:
input_boxes = input_boxes.tolist()
scores = scores.tolist()
# FIXME: class_names should be a list of strings without spaces
class_names = [class_name.strip() for class_name in class_names]
# save the results in standard format
results = {
"image_path": img_path,