feat: add grounded_sam2_tracking_demo_with_continuous_id.py and test data
.gitignore (vendored, 2 changed lines)
@@ -9,7 +9,7 @@ build/*
 _C.*
 outputs/*
 checkpoints/*.pt
-
+*test*
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
assets/tracking_video.mp4 (BIN, new file)
assets/zebra.mp4 (BIN, new file)
assets/zebra_output.mp4 (BIN, new file)
@@ -3,7 +3,7 @@ import torch
 import numpy as np
 import supervision as sv
 from supervision.draw.color import ColorPalette
-from supervision_utils import CUSTOM_COLOR_MAP
+from utils.supervision_utils import CUSTOM_COLOR_MAP
 from PIL import Image
 from sam2.build_sam import build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -7,8 +7,8 @@ from PIL import Image
 from sam2.build_sam import build_sam2_video_predictor, build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
-from track_utils import sample_points_from_masks
-from video_utils import create_video_from_images
+from utils.track_utils import sample_points_from_masks
+from utils.video_utils import create_video_from_images


 """
@@ -40,10 +40,11 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).

 # setup the input image and text prompt for SAM 2 and Grounding DINO
 # VERY important: text queries need to be lowercased + end with a dot
-text = "children."
+text = "car."

-# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
-video_dir = "notebooks/videos/bedroom"
+# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
+
+video_dir = "notebooks/videos/car"

 # scan all the JPEG frame names in this directory
 frame_names = [
@@ -10,8 +10,8 @@ from PIL import Image
 from sam2.build_sam import build_sam2_video_predictor, build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
-from track_utils import sample_points_from_masks
-from video_utils import create_video_from_images
+from utils.track_utils import sample_points_from_masks
+from utils.video_utils import create_video_from_images

 """
 Hyperparam for Ground and Tracking
@@ -17,8 +17,8 @@ from tqdm import tqdm
 from PIL import Image
 from sam2.build_sam import build_sam2_video_predictor, build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-from track_utils import sample_points_from_masks
-from video_utils import create_video_from_images
+from utils.track_utils import sample_points_from_masks
+from utils.video_utils import create_video_from_images

 """
 Hyperparam for Ground and Tracking
grounded_sam2_tracking_demo_with_continuous_id.py (new file, 201 lines)
@@ -0,0 +1,201 @@
import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionatyModel, ObjectInfo
import json
import copy

"""
Step 1: Environment settings and model initialization
"""
# use bfloat16 for the entire script
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
image_predictor = SAM2ImagePredictor(sam2_image_model)

# init grounding dino model from huggingface
model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "car."

# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/car"
# `output_dir` is the directory to save the annotated frames
output_dir = "./outputs"
# `output_video_path` is the path to save the final video
output_video_path = "./outputs/output.mp4"
# create the output directories
CommonUtils.creat_dirs(output_dir)
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
result_dir = os.path.join(output_dir, "result")
CommonUtils.creat_dirs(mask_data_dir)
CommonUtils.creat_dirs(json_data_dir)

# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)
step = 10  # re-run the Grounding DINO detector every `step` frames

sam2_masks = MaskDictionatyModel()
PROMPT_TYPE_FOR_VIDEO = "mask"  # box, mask or point
objects_count = 0

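# How the pieces below fit together: every `step` frames, Grounding DINO
# proposes boxes for `text`, the SAM 2 image predictor turns those boxes into
# masks, and MaskDictionatyModel.update_masks matches the new masks against
# the masks tracked so far (IoU > 0.8), reusing a matched instance id or
# minting a new one; the SAM 2 video predictor then propagates the registered
# masks over the next `step` frames.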
"""
|
||||
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
|
||||
"""
|
||||
print("Total frames:", len(frame_names))
|
||||
for start_frame_idx in range(0, len(frame_names), step):
|
||||
# prompt grounding dino to get the box coordinates on specific frame
|
||||
print("start_frame_idx", start_frame_idx)
|
||||
# continue
|
||||
img_path = os.path.join(video_dir, frame_names[start_frame_idx])
|
||||
image = Image.open(img_path)
|
||||
image_base_name = frame_names[start_frame_idx].split(".")[0]
|
||||
mask_dict = MaskDictionatyModel(promote_type = PROMPT_TYPE_FOR_VIDEO, mask_name = f"mask_{image_base_name}.npy")
|
||||
|
||||
# run Grounding DINO on the image
|
||||
inputs = processor(images=image, text=text, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = grounding_model(**inputs)
|
||||
|
||||
results = processor.post_process_grounded_object_detection(
|
||||
outputs,
|
||||
inputs.input_ids,
|
||||
box_threshold=0.25,
|
||||
text_threshold=0.25,
|
||||
target_sizes=[image.size[::-1]]
|
||||
)
|
||||
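    # `results` is a list with one entry per image: results[0]["boxes"] holds
    # (num_boxes, 4) xyxy coordinates and results[0]["labels"] the matched
    # text phrases, which are used as class names below.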

    # prompt SAM image predictor to get the mask for the object
    image_predictor.set_image(np.array(image.convert("RGB")))

    # process the detection results
    input_boxes = results[0]["boxes"]
    OBJECTS = results[0]["labels"]

    # prompt SAM 2 image predictor to get the mask for the object
    masks, scores, logits = image_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )
    # convert the mask shape to (n, H, W)
    if masks.ndim == 2:
        masks = masks[None]
        scores = scores[None]
        logits = logits[None]
    elif masks.ndim == 4:
        masks = masks.squeeze(1)

    """
    Step 3: Register each object's positive points to video predictor
    """

    # If you are using point prompts, we uniformly sample positive points based on the mask
    if mask_dict.promote_type == "mask":
        mask_dict.add_new_frame_annotation(mask_list=torch.tensor(masks).to(device), box_list=torch.tensor(input_boxes), label_list=OBJECTS)
    else:
        raise NotImplementedError("SAM 2 video predictor only supports mask prompts")
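    # note: only mask prompts are wired up here; `sample_points_from_masks`
    # (imported above) would be the hook for point prompts, per the
    # box/mask/point options noted next to PROMPT_TYPE_FOR_VIDEO.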

    """
    Step 4: Propagate the video predictor to get the segmentation results for each frame
    """
    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
    print("objects_count", objects_count)
    if len(mask_dict.labels) == 0:
        print("No object detected in the frame, skipping frame {}".format(start_frame_idx))
        continue
    video_predictor.reset_state(inference_state)

    for object_id, object_info in mask_dict.labels.items():
        frame_idx, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
            inference_state,
            start_frame_idx,
            object_id,
            object_info.mask,
        )
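
    # every matched or newly minted object is registered at `start_frame_idx`,
    # so the propagation below emits masks keyed by stable instance ids.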
    video_segments = {}  # output the following {step} frames tracking masks
    for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
        frame_masks = MaskDictionatyModel()

        for i, out_obj_id in enumerate(out_obj_ids):
            out_mask = (out_mask_logits[i] > 0.0)
            object_info = ObjectInfo(instance_id=out_obj_id, mask=out_mask[0], class_name=mask_dict.get_target_class_name(out_obj_id))
            object_info.update_box()
            frame_masks.labels[out_obj_id] = object_info
            image_base_name = frame_names[out_frame_idx].split(".")[0]
            frame_masks.mask_name = f"mask_{image_base_name}.npy"
            frame_masks.mask_height = out_mask.shape[-2]
            frame_masks.mask_width = out_mask.shape[-1]

        video_segments[out_frame_idx] = frame_masks
        sam2_masks = copy.deepcopy(frame_masks)
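
        # the last propagated frame becomes the matching reference for the
        # next Grounding DINO round; this handoff is what keeps instance ids
        # continuous across detection rounds.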

    print("video_segments:", len(video_segments))

    """
    Step 5: save the tracking masks and json files
    """
    for frame_idx, frame_masks_info in video_segments.items():
        mask = frame_masks_info.labels
        mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
        for obj_id, obj_info in mask.items():
            mask_img[obj_info.mask == True] = obj_id

        mask_img = mask_img.numpy().astype(np.uint16)
        np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)
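
        # uint16 keeps the per-pixel id raster compact while allowing up to
        # 65535 distinct instance ids.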

        json_data = frame_masks_info.to_dict()
        json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
        with open(json_data_path, "w") as f:
            json.dump(json_data, f)


"""
Step 6: Draw the results and save the video
"""
CommonUtils.draw_masks_and_box(video_dir, mask_data_dir, json_data_dir, result_dir)

create_video_from_images(result_dir, output_video_path, frame_rate=15)
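For reference, a minimal sketch of consuming the artifacts this script writes (not part of the commit; assumes the demo has already populated ./outputs, and the frame name "00001" is illustrative):

import os, json
import numpy as np

frame = "00001"
# (H, W) uint16 raster, 0 = background, other values = instance ids
mask = np.load(os.path.join("./outputs/mask_data", f"mask_{frame}.npy"))
with open(os.path.join("./outputs/json_data", f"mask_{frame}.json")) as f:
    meta = json.load(f)
for obj_id, obj in meta["labels"].items():  # json.dump stringifies the int keys
    area = int((mask == int(obj_id)).sum())
    print(obj_id, obj["class_name"], (obj["x1"], obj["y1"], obj["x2"], obj["y2"]), "area:", area)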
@@ -14,8 +14,8 @@ import supervision as sv
 from PIL import Image
 from sam2.build_sam import build_sam2_video_predictor, build_sam2
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-from track_utils import sample_points_from_masks
-from video_utils import create_video_from_images
+from utils.track_utils import sample_points_from_masks
+from utils.video_utils import create_video_from_images


 """
notebooks/videos/car/00001.jpg (BIN, new file, 131 KiB)
notebooks/videos/car/00002.jpg (BIN, new file, 156 KiB)
notebooks/videos/car/00003.jpg (BIN, new file, 195 KiB)
notebooks/videos/car/00004.jpg (BIN, new file, 174 KiB)
notebooks/videos/car/00005.jpg (BIN, new file, 155 KiB)
notebooks/videos/car/00006.jpg (BIN, new file, 120 KiB)
notebooks/videos/car/00007.jpg (BIN, new file, 103 KiB)
notebooks/videos/car/00008.jpg (BIN, new file, 94 KiB)
notebooks/videos/car/00009.jpg (BIN, new file, 84 KiB)
notebooks/videos/car/00010.jpg (BIN, new file, 77 KiB)
notebooks/videos/car/00011.jpg (BIN, new file, 72 KiB)
notebooks/videos/car/00012.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00013.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00014.jpg (BIN, new file, 69 KiB)
notebooks/videos/car/00015.jpg (BIN, new file, 69 KiB)
notebooks/videos/car/00016.jpg (BIN, new file, 69 KiB)
notebooks/videos/car/00017.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00018.jpg (BIN, new file, 69 KiB)
notebooks/videos/car/00019.jpg (BIN, new file, 69 KiB)
notebooks/videos/car/00020.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00021.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00022.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00023.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00024.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00025.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00026.jpg (BIN, new file, 70 KiB)
notebooks/videos/car/00027.jpg (BIN, new file, 71 KiB)
notebooks/videos/car/00028.jpg (BIN, new file, 71 KiB)
notebooks/videos/car/00029.jpg (BIN, new file, 71 KiB)
notebooks/videos/car/00030.jpg (BIN, new file, 72 KiB)
utils/common_utils.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import os
import json
import cv2
import numpy as np
from dataclasses import dataclass
import random

class CommonUtils:
    @staticmethod
    def creat_dirs(path):
        """
        Ensure the given path exists. If it does not exist, create it using os.makedirs.

        :param path: The directory path to check or create.
        """
        try:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)
                print(f"Path '{path}' did not exist and has been created.")
            else:
                print(f"Path '{path}' already exists.")
        except Exception as e:
            print(f"An error occurred while creating the path: {e}")

    @staticmethod
    def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
        CommonUtils.creat_dirs(output_path)
        raw_image_name_list = os.listdir(raw_image_path)
        raw_image_name_list.sort()
        for raw_image_name in raw_image_name_list:
            image_path = os.path.join(raw_image_path, raw_image_name)
            image = cv2.imread(image_path)
            if image is None:
                raise FileNotFoundError(f"Image file not found: {image_path}")
            # load mask
            mask_npy_path = os.path.join(mask_path, "mask_" + raw_image_name.split(".")[0] + ".npy")
            mask = np.load(mask_npy_path)
            # color map
            unique_ids = np.unique(mask)
            colors = {uid: CommonUtils.random_color() for uid in unique_ids}
            colors[0] = (0, 0, 0)  # background color

            # apply mask to image
            colored_mask = np.zeros_like(image)
            for uid in unique_ids:
                colored_mask[mask == uid] = colors[uid]
            alpha = 0.5  # adjust alpha to change the overlay transparency
            output_image = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)

            file_path = os.path.join(json_path, "mask_" + raw_image_name.split(".")[0] + ".json")
            with open(file_path, 'r') as file:
                json_data = json.load(file)
            # Draw bounding boxes and labels
            for obj_id, obj_item in json_data["labels"].items():
                # Extract data from JSON
                x1, y1, x2, y2 = obj_item["x1"], obj_item["y1"], obj_item["x2"], obj_item["y2"]
                instance_id = obj_item["instance_id"]
                class_name = obj_item["class_name"]

                # Draw rectangle
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

                # Put text
                label = f"{instance_id}: {class_name}"
                cv2.putText(output_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            # Save the modified image
            output_image_path = os.path.join(output_path, raw_image_name)
            cv2.imwrite(output_image_path, output_image)

            print(f"Annotated image saved as {output_image_path}")

    @staticmethod
    def random_color():
        """random color generator"""
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
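A minimal usage sketch for CommonUtils (not part of the commit; paths are illustrative and assume the demo's output layout):

from utils.common_utils import CommonUtils

CommonUtils.creat_dirs("./outputs/result")   # idempotent directory creation
print(CommonUtils.random_color())            # e.g. (17, 204, 88) -- a BGR tuple for OpenCV
# overlay the saved id rasters and draw the saved boxes onto the raw frames:
CommonUtils.draw_masks_and_box("notebooks/videos/car", "./outputs/mask_data",
                               "./outputs/json_data", "./outputs/result")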
utils/mask_dictionary_model.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import numpy as np
import json
import torch
import copy
import os
import cv2
from dataclasses import dataclass, field
from typing import Any

@dataclass
class MaskDictionatyModel:
    mask_name: str = ""
    mask_height: int = 1080
    mask_width: int = 1920
    promote_type: str = "mask"
    labels: dict = field(default_factory=dict)

    def add_new_frame_annotation(self, mask_list, box_list, label_list, background_value=0):
        mask_img = torch.zeros(mask_list.shape[-2:])
        anno_2d = {}
        for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
            final_index = background_value + idx + 1

            if mask.shape[0] != mask_img.shape[0] or mask.shape[1] != mask_img.shape[1]:
                raise ValueError("The mask shape should be the same as the mask_img shape.")
            mask_img[mask == True] = final_index
            name = label
            new_annotation = ObjectInfo(instance_id=final_index, mask=mask, class_name=name, x1=box[0], y1=box[1], x2=box[2], y2=box[3])
            anno_2d[final_index] = new_annotation

        self.mask_height = mask_img.shape[0]
        self.mask_width = mask_img.shape[1]
        self.labels = anno_2d

    def update_masks(self, tracking_annotation_dict, iou_threshold=0.8, objects_count=0):
        updated_masks = {}

        for seg_obj_id, seg_mask in self.labels.items():  # masks newly detected by Grounding DINO + SAM
            flag = 0
            new_mask_copy = ObjectInfo()
            if seg_mask.mask.sum() == 0:
                continue

            for object_id, object_info in tracking_annotation_dict.labels.items():  # masks tracked from previous frames
                iou = self.calculate_iou(seg_mask.mask, object_info.mask)
                if iou > iou_threshold:
                    flag = object_info.instance_id
                    new_mask_copy.mask = seg_mask.mask
                    new_mask_copy.instance_id = object_info.instance_id
                    new_mask_copy.class_name = seg_mask.class_name
                    break

            if not flag:
                objects_count += 1
                flag = objects_count
                new_mask_copy.instance_id = objects_count
                new_mask_copy.mask = seg_mask.mask
                new_mask_copy.class_name = seg_mask.class_name
            updated_masks[flag] = new_mask_copy
        self.labels = updated_masks
        return objects_count

    def get_target_class_name(self, instance_id):
        return self.labels[instance_id].class_name

    @staticmethod
    def calculate_iou(mask1, mask2):
        # Convert masks to float tensors for calculations
        mask1 = mask1.to(torch.float32)
        mask2 = mask2.to(torch.float32)

        # Calculate intersection and union
        intersection = (mask1 * mask2).sum()
        union = mask1.sum() + mask2.sum() - intersection

        # Calculate IoU
        iou = intersection / union
        return iou

    def to_dict(self):
        return {
            "mask_name": self.mask_name,
            "mask_height": self.mask_height,
            "mask_width": self.mask_width,
            "promote_type": self.promote_type,
            "labels": {k: v.to_dict() for k, v in self.labels.items()}
        }


@dataclass
class ObjectInfo:
    instance_id: int = 0
    mask: Any = None
    class_name: str = ""
    x1: int = 0
    y1: int = 0
    x2: int = 0
    y2: int = 0
    logit: float = 0.0

    def get_mask(self):
        return self.mask

    def get_id(self):
        return self.instance_id

    def update_box(self):
        # find the indices of all nonzero mask values
        nonzero_indices = torch.nonzero(self.mask)

        # if there are none, return an empty bounding box
        if nonzero_indices.size(0) == 0:
            return []

        # compute the min and max indices
        y_min, x_min = torch.min(nonzero_indices, dim=0)[0]
        y_max, x_max = torch.max(nonzero_indices, dim=0)[0]

        # build the bounding box [x_min, y_min, x_max, y_max]
        bbox = [x_min.item(), y_min.item(), x_max.item(), y_max.item()]
        self.x1 = bbox[0]
        self.y1 = bbox[1]
        self.x2 = bbox[2]
        self.y2 = bbox[3]

    def to_dict(self):
        return {
            "instance_id": self.instance_id,
            "class_name": self.class_name,
            "x1": self.x1,
            "y1": self.y1,
            "x2": self.x2,
            "y2": self.y2,
            "logit": self.logit
        }
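For reference, a minimal sketch of the id-matching behaviour (not part of the commit): a detection whose mask overlaps a tracked mask with IoU above the threshold inherits the tracked instance id, anything else gets a fresh one.

import torch
from utils.mask_dictionary_model import MaskDictionatyModel, ObjectInfo

m = torch.zeros(4, 4, dtype=torch.bool)
m[0:2, 0:2] = True

tracked = MaskDictionatyModel()
tracked.labels[5] = ObjectInfo(instance_id=5, mask=m, class_name="car")

detections = MaskDictionatyModel()
detections.labels[1] = ObjectInfo(instance_id=1, mask=m.clone(), class_name="car")

count = detections.update_masks(tracking_annotation_dict=tracked, iou_threshold=0.8, objects_count=5)
print(count)                           # 5 -- no new object was minted
print(list(detections.labels.keys()))  # [5] -- detection re-keyed to the tracked id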