update visualization func
This commit is contained in:
@@ -29,7 +29,7 @@ if torch.cuda.get_device_properties(0).major >= 8:
|
|||||||
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
||||||
model_cfg = "sam2_hiera_l.yaml"
|
model_cfg = "sam2_hiera_l.yaml"
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print("device",device)
|
print("device", device)
|
||||||
|
|
||||||
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
||||||
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
|
||||||
@@ -189,10 +189,9 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
json.dump(json_data, f)
|
json.dump(json_data, f)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 6: Draw the results and save the video
|
Step 6: Draw the results and save the video
|
||||||
"""
|
"""
|
||||||
CommonUtils.draw_masks_and_box(video_dir, mask_data_dir, json_data_dir, result_dir)
|
CommonUtils.draw_masks_and_box_with_supervision(video_dir, mask_data_dir, json_data_dir, result_dir)
|
||||||
|
|
||||||
create_video_from_images(result_dir, output_video_path, frame_rate=30)
|
create_video_from_images(result_dir, output_video_path, frame_rate=30)
|
@@ -3,6 +3,7 @@ import json
|
|||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
import supervision as sv
|
||||||
import random
|
import random
|
||||||
|
|
||||||
class CommonUtils:
|
class CommonUtils:
|
||||||
@@ -21,7 +22,90 @@ class CommonUtils:
|
|||||||
print(f"Path '{path}' already exists.")
|
print(f"Path '{path}' already exists.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An error occurred while creating the path: {e}")
|
print(f"An error occurred while creating the path: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def draw_masks_and_box_with_supervision(raw_image_path, mask_path, json_path, output_path):
|
||||||
|
CommonUtils.creat_dirs(output_path)
|
||||||
|
raw_image_name_list = os.listdir(raw_image_path)
|
||||||
|
raw_image_name_list.sort()
|
||||||
|
for raw_image_name in raw_image_name_list:
|
||||||
|
image_path = os.path.join(raw_image_path, raw_image_name)
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
raise FileNotFoundError("Image file not found.")
|
||||||
|
# load mask
|
||||||
|
mask_npy_path = os.path.join(mask_path, "mask_"+raw_image_name.split(".")[0]+".npy")
|
||||||
|
mask = np.load(mask_npy_path)
|
||||||
|
# color map
|
||||||
|
unique_ids = np.unique(mask)
|
||||||
|
|
||||||
|
# get each mask from unique mask file
|
||||||
|
all_object_masks = []
|
||||||
|
for uid in unique_ids:
|
||||||
|
if uid == 0: # skip background id
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
object_mask = (mask == uid)
|
||||||
|
all_object_masks.append(object_mask[None])
|
||||||
|
|
||||||
|
# get n masks: (n, h, w)
|
||||||
|
all_object_masks = np.concatenate(all_object_masks, axis=0)
|
||||||
|
|
||||||
|
# load box information
|
||||||
|
file_path = os.path.join(json_path, "mask_"+raw_image_name.split(".")[0]+".json")
|
||||||
|
|
||||||
|
all_object_boxes = []
|
||||||
|
all_object_ids = []
|
||||||
|
all_class_names = []
|
||||||
|
object_id_to_name = {}
|
||||||
|
with open(file_path, "r") as file:
|
||||||
|
json_data = json.load(file)
|
||||||
|
for obj_id, obj_item in json_data["labels"].items():
|
||||||
|
# box id
|
||||||
|
instance_id = obj_item["instance_id"]
|
||||||
|
if instance_id not in unique_ids: # not a valid box
|
||||||
|
continue
|
||||||
|
# box coordinates
|
||||||
|
x1, y1, x2, y2 = obj_item["x1"], obj_item["y1"], obj_item["x2"], obj_item["y2"]
|
||||||
|
all_object_boxes.append([x1, y1, x2, y2])
|
||||||
|
# box name
|
||||||
|
class_name = obj_item["class_name"]
|
||||||
|
|
||||||
|
# build id list and id2name mapping
|
||||||
|
all_object_ids.append(instance_id)
|
||||||
|
all_class_names.append(class_name)
|
||||||
|
object_id_to_name[instance_id] = class_name
|
||||||
|
|
||||||
|
# Adjust object id and boxes to ascending order
|
||||||
|
paired_id_and_box = zip(all_object_ids, all_object_boxes)
|
||||||
|
sorted_pair = sorted(paired_id_and_box, key=lambda pair: pair[0])
|
||||||
|
|
||||||
|
# Because we get the mask data as ascending order, so we also need to ascend box and ids
|
||||||
|
all_object_ids = [pair[0] for pair in sorted_pair]
|
||||||
|
all_object_boxes = [pair[1] for pair in sorted_pair]
|
||||||
|
|
||||||
|
detections = sv.Detections(
|
||||||
|
xyxy=np.array(all_object_boxes),
|
||||||
|
mask=all_object_masks,
|
||||||
|
class_id=np.array(all_object_ids, dtype=np.int32),
|
||||||
|
)
|
||||||
|
|
||||||
|
# custom label to show both id and class name
|
||||||
|
labels = [
|
||||||
|
f"{instance_id}: {class_name}" for instance_id, class_name in zip(all_object_ids, all_class_names)
|
||||||
|
]
|
||||||
|
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections)
|
||||||
|
label_annotator = sv.LabelAnnotator()
|
||||||
|
annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=labels)
|
||||||
|
mask_annotator = sv.MaskAnnotator()
|
||||||
|
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
|
||||||
|
|
||||||
|
output_image_path = os.path.join(output_path, raw_image_name)
|
||||||
|
cv2.imwrite(output_image_path, annotated_frame)
|
||||||
|
print(f"Annotated image saved as {output_image_path}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
|
def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
|
||||||
CommonUtils.creat_dirs(output_path)
|
CommonUtils.creat_dirs(output_path)
|
||||||
@@ -40,7 +124,7 @@ class CommonUtils:
|
|||||||
colors = {uid: CommonUtils.random_color() for uid in unique_ids}
|
colors = {uid: CommonUtils.random_color() for uid in unique_ids}
|
||||||
colors[0] = (0, 0, 0) # background color
|
colors[0] = (0, 0, 0) # background color
|
||||||
|
|
||||||
# apply mask to image
|
# apply mask to image in RBG channels
|
||||||
colored_mask = np.zeros_like(image)
|
colored_mask = np.zeros_like(image)
|
||||||
for uid in unique_ids:
|
for uid in unique_ids:
|
||||||
colored_mask[mask == uid] = colors[uid]
|
colored_mask[mask == uid] = colors[uid]
|
||||||
|
Reference in New Issue
Block a user