feat: add grounded_sam2_tracking_demo_with_continuous_id.py and test data

This commit is contained in:
bd8090
2024-08-09 02:33:24 +02:00
parent 04ad096725
commit df626551c4
45 changed files with 434 additions and 13 deletions

77
utils/common_utils.py Normal file
View File

@@ -0,0 +1,77 @@
import os
import json
import cv2
import numpy as np
from dataclasses import dataclass
import random
class CommonUtils:
@staticmethod
def creat_dirs(path):
"""
Ensure the given path exists. If it does not exist, create it using os.makedirs.
:param path: The directory path to check or create.
"""
try:
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
print(f"Path '{path}' did not exist and has been created.")
else:
print(f"Path '{path}' already exists.")
except Exception as e:
print(f"An error occurred while creating the path: {e}")
@staticmethod
def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
CommonUtils.creat_dirs(output_path)
raw_image_name_list = os.listdir(raw_image_path)
raw_image_name_list.sort()
for raw_image_name in raw_image_name_list:
image_path = os.path.join(raw_image_path, raw_image_name)
image = cv2.imread(image_path)
if image is None:
raise FileNotFoundError("Image file not found.")
# load mask
mask_npy_path = os.path.join(mask_path, "mask_"+raw_image_name.split(".")[0]+".npy")
mask = np.load(mask_npy_path)
# color map
unique_ids = np.unique(mask)
colors = {uid: CommonUtils.random_color() for uid in unique_ids}
colors[0] = (0, 0, 0) # background color
# apply mask to image
colored_mask = np.zeros_like(image)
for uid in unique_ids:
colored_mask[mask == uid] = colors[uid]
alpha = 0.5 # 调整 alpha 值以改变透明度
output_image = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)
file_path = os.path.join(json_path, "mask_"+raw_image_name.split(".")[0]+".json")
with open(file_path, 'r') as file:
json_data = json.load(file)
# Draw bounding boxes and labels
for obj_id, obj_item in json_data["labels"].items():
# Extract data from JSON
x1, y1, x2, y2 = obj_item["x1"], obj_item["y1"], obj_item["x2"], obj_item["y2"]
instance_id = obj_item["instance_id"]
class_name = obj_item["class_name"]
# Draw rectangle
cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Put text
label = f"{instance_id}: {class_name}"
cv2.putText(output_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
# Save the modified image
output_image_path = os.path.join(output_path, raw_image_name)
cv2.imwrite(output_image_path, output_image)
print(f"Annotated image saved as {output_image_path}")
@staticmethod
def random_color():
"""random color generator"""
return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

View File

@@ -0,0 +1,142 @@
import numpy as np
import json
import torch
import copy
import os
import cv2
from dataclasses import dataclass, field
@dataclass
class MaskDictionatyModel:
mask_name:str = ""
mask_height: int = 1080
mask_width:int = 1920
promote_type:str = "mask"
labels:dict = field(default_factory=dict)
def add_new_frame_annotation(self, mask_list, box_list, label_list, background_value = 0):
mask_img = torch.zeros(mask_list.shape[-2:])
anno_2d = {}
for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
final_index = background_value + idx + 1
if mask.shape[0] != mask_img.shape[0] or mask.shape[1] != mask_img.shape[1]:
raise ValueError("The mask shape should be the same as the mask_img shape.")
# mask = mask
mask_img[mask == True] = final_index
# print("label", label)
name = label
box = box # .numpy().tolist()
new_annotation = ObjectInfo(instance_id = final_index, mask = mask, class_name = name, x1 = box[0], y1 = box[1], x2 = box[2], y2 = box[3])
anno_2d[final_index] = new_annotation
# np.save(os.path.join(output_dir, output_file_name), mask_img.numpy().astype(np.uint16))
self.mask_height = mask_img.shape[0]
self.mask_width = mask_img.shape[1]
self.labels = anno_2d
def update_masks(self, tracking_annotation_dict, iou_threshold=0.8, objects_count=0):
updated_masks = {}
for seg_obj_id, seg_mask in self.labels.items(): # tracking_masks
flag = 0
new_mask_copy = ObjectInfo()
if seg_mask.mask.sum() == 0:
continue
for object_id, object_info in tracking_annotation_dict.labels.items(): # grounded_sam masks
iou = self.calculate_iou(seg_mask.mask, object_info.mask) # tensor, numpy
# print("iou", iou)
if iou > iou_threshold:
flag = object_info.instance_id
new_mask_copy.mask = seg_mask.mask
new_mask_copy.instance_id = object_info.instance_id
new_mask_copy.class_name = seg_mask.class_name
break
if not flag:
objects_count += 1
flag = objects_count
new_mask_copy.instance_id = objects_count
new_mask_copy.mask = seg_mask.mask
new_mask_copy.class_name = seg_mask.class_name
updated_masks[flag] = new_mask_copy
self.labels = updated_masks
return objects_count
def get_target_class_name(self, instance_id):
return self.labels[instance_id].class_name
@staticmethod
def calculate_iou(mask1, mask2):
# Convert masks to float tensors for calculations
mask1 = mask1.to(torch.float32)
mask2 = mask2.to(torch.float32)
# Calculate intersection and union
intersection = (mask1 * mask2).sum()
union = mask1.sum() + mask2.sum() - intersection
# Calculate IoU
iou = intersection / union
return iou
def to_dict(self):
return {
"mask_name": self.mask_name,
"mask_height": self.mask_height,
"mask_width": self.mask_width,
"promote_type": self.promote_type,
"labels": {k: v.to_dict() for k, v in self.labels.items()}
}
@dataclass
class ObjectInfo:
instance_id:int = 0
mask: any = None
class_name:str = ""
x1:int = 0
y1:int = 0
x2:int = 0
y2:int = 0
logit:float = 0.0
def get_mask(self):
return self.mask
def get_id(self):
return self.instance_id
def update_box(self):
# 找到所有非零值的索引
nonzero_indices = torch.nonzero(self.mask)
# 如果没有非零值,返回一个空的边界框
if nonzero_indices.size(0) == 0:
# print("nonzero_indices", nonzero_indices)
return []
# 计算最小和最大索引
y_min, x_min = torch.min(nonzero_indices, dim=0)[0]
y_max, x_max = torch.max(nonzero_indices, dim=0)[0]
# 创建边界框 [x_min, y_min, x_max, y_max]
bbox = [x_min.item(), y_min.item(), x_max.item(), y_max.item()]
self.x1 = bbox[0]
self.y1 = bbox[1]
self.x2 = bbox[2]
self.y2 = bbox[3]
def to_dict(self):
return {
"instance_id": self.instance_id,
"class_name": self.class_name,
"x1": self.x1,
"y1": self.y1,
"x2": self.x2,
"y2": self.y2,
"logit": self.logit
}

View File

@@ -0,0 +1,18 @@
CUSTOM_COLOR_MAP = [
"#e6194b",
"#3cb44b",
"#ffe119",
"#0082c8",
"#f58231",
"#911eb4",
"#46f0f0",
"#f032e6",
"#d2f53c",
"#fabebe",
"#008080",
"#e6beff",
"#aa6e28",
"#fffac8",
"#800000",
"#aaffc3",
]

41
utils/track_utils.py Normal file
View File

@@ -0,0 +1,41 @@
import numpy as np
def sample_points_from_masks(masks, num_points):
"""
sample points from masks and return its absolute coordinates
Args:
masks: np.array with shape (n, h, w)
num_points: int
Returns:
points: np.array with shape (n, points, 2)
"""
n, h, w = masks.shape
points = []
for i in range(n):
# find the valid mask points
indices = np.argwhere(masks[i] == 1)
# the output format of np.argwhere is (y, x) and the shape is (num_points, 2)
# we should convert it to (x, y)
indices = indices[:, ::-1] # (num_points, [y x]) to (num_points, [x y])
# import pdb; pdb.set_trace()
if len(indices) == 0:
# if there are no valid points, append an empty array
points.append(np.array([]))
continue
# resampling if there's not enough points
if len(indices) < num_points:
sampled_indices = np.random.choice(len(indices), num_points, replace=True)
else:
sampled_indices = np.random.choice(len(indices), num_points, replace=False)
sampled_points = indices[sampled_indices]
points.append(sampled_points)
# convert to np.array
points = np.array(points, dtype=np.float32)
return points

35
utils/video_utils.py Normal file
View File

@@ -0,0 +1,35 @@
import cv2
import os
from tqdm import tqdm
def create_video_from_images(image_folder, output_video_path, frame_rate=25):
# define valid extension
valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG"]
# get all image files in the folder
image_files = [f for f in os.listdir(image_folder)
if os.path.splitext(f)[1] in valid_extensions]
image_files.sort() # sort the files in alphabetical order
print(image_files)
if not image_files:
raise ValueError("No valid image files found in the specified folder.")
# load the first image to get the dimensions of the video
first_image_path = os.path.join(image_folder, image_files[0])
first_image = cv2.imread(first_image_path)
height, width, _ = first_image.shape
# create a video writer
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # codec for saving the video
video_writer = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height))
# write each image to the video
for image_file in tqdm(image_files):
image_path = os.path.join(image_folder, image_file)
image = cv2.imread(image_path)
video_writer.write(image)
# source release
video_writer.release()
print(f"Video saved at {output_video_path}")