feat: add grounded_sam2_tracking_demo_with_continuous_id.py and test data

This commit is contained in:
bd8090
2024-08-09 02:33:24 +02:00
parent 04ad096725
commit df626551c4
45 changed files with 434 additions and 13 deletions

2
.gitignore vendored
View File

@@ -9,7 +9,7 @@ build/*
_C.* _C.*
outputs/* outputs/*
checkpoints/*.pt checkpoints/*.pt
*test*
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]

BIN
assets/tracking_video.mp4 Normal file

Binary file not shown.

BIN
assets/zebra.mp4 Normal file

Binary file not shown.

BIN
assets/zebra_output.mp4 Normal file

Binary file not shown.

View File

@@ -3,7 +3,7 @@ import torch
import numpy as np import numpy as np
import supervision as sv import supervision as sv
from supervision.draw.color import ColorPalette from supervision.draw.color import ColorPalette
from supervision_utils import CUSTOM_COLOR_MAP from utils.supervision_utils import CUSTOM_COLOR_MAP
from PIL import Image from PIL import Image
from sam2.build_sam import build_sam2 from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor

View File

@@ -7,8 +7,8 @@ from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2 from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from track_utils import sample_points_from_masks from utils.track_utils import sample_points_from_masks
from video_utils import create_video_from_images from utils.video_utils import create_video_from_images
""" """
@@ -40,10 +40,11 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).
# setup the input image and text prompt for SAM 2 and Grounding DINO # setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot # VERY important: text queries need to be lowercased + end with a dot
text = "children." text = "car."
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg` # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/bedroom"
video_dir = "notebooks/videos/car"
# scan all the JPEG frame names in this directory # scan all the JPEG frame names in this directory
frame_names = [ frame_names = [

View File

@@ -10,8 +10,8 @@ from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2 from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from track_utils import sample_points_from_masks from utils.track_utils import sample_points_from_masks
from video_utils import create_video_from_images from utils.video_utils import create_video_from_images
""" """
Hyperparam for Ground and Tracking Hyperparam for Ground and Tracking

View File

@@ -17,8 +17,8 @@ from tqdm import tqdm
from PIL import Image from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2 from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from track_utils import sample_points_from_masks from utils.track_utils import sample_points_from_masks
from video_utils import create_video_from_images from utils.video_utils import create_video_from_images
""" """
Hyperparam for Ground and Tracking Hyperparam for Ground and Tracking

View File

@@ -0,0 +1,201 @@
import os
import cv2
import torch
import numpy as np
import supervision as sv
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionatyModel, ObjectInfo
import json
import copy
"""
Step 1: Environment settings and model initialization
"""
# Pick the compute device first so the CUDA-only setup below can be guarded
# instead of failing on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device",device)

if device == "cuda":
    # use bfloat16 for the entire script: the autocast context is entered once
    # here and is deliberately never exited
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
# Both SAM 2 models are built on the same device.
# NOTE(review): whether the video predictor can actually run on CPU is not
# verified here - confirm before relying on the fallback.
video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
image_predictor = SAM2ImagePredictor(sam2_image_model)

# init grounding dino model from huggingface
model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

# setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot
text = "car."

# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/car"
# 'output_dir' is the directory to save the annotated frames
output_dir = "./outputs"
# 'output_video_path' is the path to save the final video
output_video_path = "./outputs/output.mp4"

# create the output directories (per-frame id masks, per-frame json metadata);
# `result_dir` holds the annotated frames and is created when they are drawn
CommonUtils.creat_dirs(output_dir)
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
result_dir = os.path.join(output_dir, "result")
CommonUtils.creat_dirs(mask_data_dir)
CommonUtils.creat_dirs(json_data_dir)

# scan all the JPEG frame names in this directory and order them by their
# integer frame index (the file stem must therefore be an integer)
# NOTE(review): presumably this extension filter has to match the one used by
# `video_predictor.init_state` so that frame indices line up - confirm.
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=video_dir)

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)
step = 10  # the step to sample frames for groundedDino predictor

# masks of the most recently tracked frame; new detections are matched against
# it so that an already tracked object keeps its instance id
sam2_masks = MaskDictionatyModel()
PROMPT_TYPE_FOR_VIDEO = "mask"  # box, mask or point
objects_count = 0  # running number of distinct instance ids handed out so far
"""
Step 2: Prompt Grounding DINO and SAM image predictor to get the box and mask for all frames
"""
print("Total frames:", len(frame_names))
for start_frame_idx in range(0, len(frame_names), step):
    # prompt grounding dino to get the box coordinates on specific frame
    print("start_frame_idx", start_frame_idx)
    img_path = os.path.join(video_dir, frame_names[start_frame_idx])
    image = Image.open(img_path)
    image_base_name = frame_names[start_frame_idx].split(".")[0]
    mask_dict = MaskDictionatyModel(promote_type=PROMPT_TYPE_FOR_VIDEO, mask_name=f"mask_{image_base_name}.npy")

    # run Grounding DINO on the image
    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = grounding_model(**inputs)

    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=0.25,
        text_threshold=0.25,
        target_sizes=[image.size[::-1]]
    )

    # process the detection results
    input_boxes = results[0]["boxes"]
    OBJECTS = results[0]["labels"]

    if len(input_boxes) == 0:
        # Nothing was detected, so there is nothing to segment or track from
        # this frame. Skip before prompting SAM 2 with an empty box batch.
        # NOTE(review): an empty box prompt may fail inside the image
        # predictor - confirm against the SAM 2 version in use.
        video_predictor.reset_state(inference_state)
        print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
        continue

    # prompt SAM 2 image predictor to get the mask for each detected box
    image_predictor.set_image(np.array(image.convert("RGB")))
    masks, _, _ = image_predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )
    # convert the mask shape to (n, H, W)
    if masks.ndim == 2:
        masks = masks[None]
    elif masks.ndim == 4:
        masks = masks.squeeze(1)

    # Step 3: Register each object's mask with the bookkeeping model
    if mask_dict.promote_type == "mask":
        # `as_tensor` avoids the copy-construct warning raised by
        # `torch.tensor(<tensor>)` for the boxes.
        mask_dict.add_new_frame_annotation(
            mask_list=torch.as_tensor(masks).to(device),
            box_list=torch.as_tensor(input_boxes),
            label_list=OBJECTS,
        )
    else:
        raise NotImplementedError("SAM 2 video predictor only support mask prompts")

    # Step 4: Propagate the video predictor to get the segmentation results for each frame
    # Re-key the new detections: a detection overlapping an already tracked
    # mask keeps that instance id, every other detection gets a fresh id.
    objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
    print("objects_count", objects_count)
    video_predictor.reset_state(inference_state)
    if len(mask_dict.labels) == 0:
        print("No object detected in the frame, skip the frame {}".format(start_frame_idx))
        continue

    for object_id, object_info in mask_dict.labels.items():
        video_predictor.add_new_mask(
            inference_state,
            start_frame_idx,
            object_id,
            object_info.mask,
        )

    video_segments = {}  # output the following {step} frames tracking masks
    frame_masks = None
    for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state, max_frame_num_to_track=step, start_frame_idx=start_frame_idx):
        frame_masks = MaskDictionatyModel()
        # per-frame metadata does not depend on the object, so set it once
        image_base_name = frame_names[out_frame_idx].split(".")[0]
        frame_masks.mask_name = f"mask_{image_base_name}.npy"
        frame_masks.mask_height = out_mask_logits.shape[-2]
        frame_masks.mask_width = out_mask_logits.shape[-1]

        for i, out_obj_id in enumerate(out_obj_ids):
            out_mask = (out_mask_logits[i] > 0.0)
            object_info = ObjectInfo(instance_id=out_obj_id, mask=out_mask[0], class_name=mask_dict.get_target_class_name(out_obj_id))
            object_info.update_box()
            frame_masks.labels[out_obj_id] = object_info

        video_segments[out_frame_idx] = frame_masks

    if frame_masks is not None:
        # Only the last propagated frame is needed as the reference for the
        # next detection round, so copy it once instead of once per frame.
        sam2_masks = copy.deepcopy(frame_masks)

    print("video_segments:", len(video_segments))

    # Step 5: save the tracking masks and json files
    for frame_masks_info in video_segments.values():
        # id image: 0 is background, every other value is an instance id
        mask_img = torch.zeros(frame_masks_info.mask_height, frame_masks_info.mask_width)
        for obj_id, obj_info in frame_masks_info.labels.items():
            mask_img[obj_info.mask] = obj_id

        mask_img = mask_img.numpy().astype(np.uint16)
        np.save(os.path.join(mask_data_dir, frame_masks_info.mask_name), mask_img)

        json_data = frame_masks_info.to_dict()
        json_data_path = os.path.join(json_data_dir, frame_masks_info.mask_name.replace(".npy", ".json"))
        with open(json_data_path, "w") as f:
            json.dump(json_data, f)

# Step 6: Draw the results and save the video
CommonUtils.draw_masks_and_box(video_dir, mask_data_dir, json_data_dir, result_dir)
create_video_from_images(result_dir, output_video_path, frame_rate=15)

View File

@@ -14,8 +14,8 @@ import supervision as sv
from PIL import Image from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2 from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor from sam2.sam2_image_predictor import SAM2ImagePredictor
from track_utils import sample_points_from_masks from utils.track_utils import sample_points_from_masks
from video_utils import create_video_from_images from utils.video_utils import create_video_from_images
""" """

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 156 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 195 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 174 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 84 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 70 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 71 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

77
utils/common_utils.py Normal file
View File

@@ -0,0 +1,77 @@
import os
import json
import cv2
import numpy as np
from dataclasses import dataclass
import random
class CommonUtils:
    """Filesystem and visualisation helpers for the tracking demo."""

    @staticmethod
    def creat_dirs(path):
        """
        Ensure the given path exists. If it does not exist, create it using os.makedirs.

        Failures reported by the operating system are printed rather than
        raised, so this is a best-effort helper.

        :param path: The directory path to check or create.
        """
        try:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)
                print(f"Path '{path}' did not exist and has been created.")
            else:
                print(f"Path '{path}' already exists.")
        except OSError as e:
            print(f"An error occurred while creating the path: {e}")

    @staticmethod
    def draw_masks_and_box(raw_image_path, mask_path, json_path, output_path):
        """
        Overlay instance masks, boxes and labels on every image of a directory.

        For an image named ``<stem>.<ext>`` the id mask is read from
        ``mask_<stem>.npy`` in ``mask_path`` and the box metadata from
        ``mask_<stem>.json`` in ``json_path``. The annotated image is written
        to ``output_path`` under the original file name.

        :param raw_image_path: Directory containing the source images.
        :param mask_path: Directory containing the ``.npy`` id masks.
        :param json_path: Directory containing the ``.json`` metadata files.
        :param output_path: Directory receiving the annotated images.
        :raises FileNotFoundError: If an entry of ``raw_image_path`` cannot be
            read as an image.
        """
        CommonUtils.creat_dirs(output_path)
        for raw_image_name in sorted(os.listdir(raw_image_path)):
            image_path = os.path.join(raw_image_path, raw_image_name)
            image = cv2.imread(image_path)
            if image is None:
                raise FileNotFoundError("Image file not found.")
            stem = raw_image_name.split(".")[0]

            # load mask
            mask = np.load(os.path.join(mask_path, "mask_" + stem + ".npy"))

            # Paint every instance with a colour derived from its id, so one
            # object keeps one colour over the whole video. Id 0 is the
            # background and stays black (the canvas starts out as zeros).
            colored_mask = np.zeros_like(image)
            for uid in np.unique(mask):
                if uid == 0:
                    continue
                colored_mask[mask == uid] = CommonUtils._color_for_id(uid)

            alpha = 0.5  # adjust alpha to change the overlay transparency
            output_image = cv2.addWeighted(image, 1 - alpha, colored_mask, alpha, 0)

            file_path = os.path.join(json_path, "mask_" + stem + ".json")
            with open(file_path, 'r') as file:
                json_data = json.load(file)

            # Draw bounding boxes and labels
            for obj_item in json_data["labels"].values():
                # Extract data from JSON
                x1, y1, x2, y2 = obj_item["x1"], obj_item["y1"], obj_item["x2"], obj_item["y2"]
                instance_id = obj_item["instance_id"]
                class_name = obj_item["class_name"]

                # Draw rectangle
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)

                # Put text
                label = f"{instance_id}: {class_name}"
                cv2.putText(output_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

            # Save the modified image
            output_image_path = os.path.join(output_path, raw_image_name)
            cv2.imwrite(output_image_path, output_image)
            print(f"Annotated image saved as {output_image_path}")

    @staticmethod
    def _color_for_id(instance_id):
        """Return a colour triple that is always the same for a given id.

        Id 0 (background) maps to black; any other id seeds a private random
        generator, so the global ``random`` state is left untouched.
        """
        instance_id = int(instance_id)
        if instance_id == 0:
            return (0, 0, 0)
        rng = random.Random(instance_id)
        return (rng.randint(0, 255), rng.randint(0, 255), rng.randint(0, 255))

    @staticmethod
    def random_color():
        """random color generator"""
        return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

View File

@@ -0,0 +1,142 @@
import numpy as np
import json
import torch
import copy
import os
import cv2
from dataclasses import dataclass, field
@dataclass
class MaskDictionatyModel:
    """Per-frame collection of object masks keyed by instance id."""

    mask_name: str = ""  # file name used when the frame's id mask is saved
    mask_height: int = 1080
    mask_width: int = 1920
    promote_type: str = "mask"
    labels: dict = field(default_factory=dict)  # instance id -> ObjectInfo

    def add_new_frame_annotation(self, mask_list, box_list, label_list, background_value = 0):
        """Replace ``labels`` with one ObjectInfo per (mask, box, label) triple.

        Objects are numbered ``background_value + 1`` onwards in input order.
        ``mask_height`` / ``mask_width`` are taken from the last two
        dimensions of ``mask_list``; they are left unchanged when
        ``mask_list`` has fewer than two dimensions (e.g. an empty batch).

        :raises ValueError: If a mask's first two dimensions differ from the
            last two dimensions of ``mask_list``.
        """
        frame_shape = tuple(mask_list.shape[-2:])
        annotations = {}
        for idx, (mask, box, label) in enumerate(zip(mask_list, box_list, label_list)):
            final_index = background_value + idx + 1
            if tuple(mask.shape[:2]) != frame_shape:
                raise ValueError("The mask shape should be the same as the mask_img shape.")
            annotations[final_index] = ObjectInfo(
                instance_id = final_index,
                mask = mask,
                class_name = label,
                x1 = box[0],
                y1 = box[1],
                x2 = box[2],
                y2 = box[3],
            )
        if len(frame_shape) == 2:
            self.mask_height, self.mask_width = frame_shape
        self.labels = annotations

    def update_masks(self, tracking_annotation_dict, iou_threshold=0.8, objects_count=0):
        """Re-key this frame's objects against previously tracked ones.

        Every non-empty mask in ``self.labels`` is compared with the masks in
        ``tracking_annotation_dict.labels``. The first tracked object whose
        IoU exceeds ``iou_threshold`` lends its instance id; otherwise
        ``objects_count`` is incremented and used as a new id. Empty masks are
        dropped. Only id, mask and class name are carried over.

        :return: The updated ``objects_count``.
        """
        updated_masks = {}
        for detection in self.labels.values():  # objects of this frame
            if detection.mask.sum() == 0:
                continue
            matched_id = 0
            for tracked in tracking_annotation_dict.labels.values():  # previously tracked objects
                if self.calculate_iou(detection.mask, tracked.mask) > iou_threshold:
                    matched_id = tracked.instance_id
                    break
            if not matched_id:
                # no sufficiently overlapping tracked object: new instance
                objects_count += 1
                matched_id = objects_count
            updated_masks[matched_id] = ObjectInfo(
                instance_id = matched_id,
                mask = detection.mask,
                class_name = detection.class_name,
            )
        self.labels = updated_masks
        return objects_count

    def get_target_class_name(self, instance_id):
        """Return the class name stored for ``instance_id``."""
        return self.labels[instance_id].class_name

    @staticmethod
    def calculate_iou(mask1, mask2):
        """Return the intersection-over-union of two masks as a scalar tensor.

        Two empty masks yield 0 instead of the NaN a plain division would give.
        """
        # Convert masks to float tensors for calculations
        mask1 = mask1.to(torch.float32)
        mask2 = mask2.to(torch.float32)

        # Calculate intersection and union
        intersection = (mask1 * mask2).sum()
        union = mask1.sum() + mask2.sum() - intersection

        if union == 0:
            return torch.zeros_like(union)
        return intersection / union

    def to_dict(self):
        """Return a plain-dict view of the frame (masks themselves are omitted)."""
        return {
            "mask_name": self.mask_name,
            "mask_height": self.mask_height,
            "mask_width": self.mask_width,
            "promote_type": self.promote_type,
            "labels": {k: v.to_dict() for k, v in self.labels.items()}
        }
@dataclass
class ObjectInfo:
    """One tracked object: id, mask, class name and bounding box."""

    instance_id: int = 0
    # `update_box` passes this to torch.nonzero, so a tensor is expected there
    mask: "torch.Tensor | None" = None
    class_name: str = ""
    x1: int = 0
    y1: int = 0
    x2: int = 0
    y2: int = 0
    logit: float = 0.0

    def get_mask(self):
        """Return the stored mask."""
        return self.mask

    def get_id(self):
        """Return the instance id."""
        return self.instance_id

    def update_box(self):
        """Recompute ``x1, y1, x2, y2`` from the non-zero region of the mask.

        The mask is indexed as (row, column), so rows give the y range and
        columns give the x range; the maxima are the last non-zero indices.

        :return: ``[x_min, y_min, x_max, y_max]``, or ``[]`` when the mask has
            no non-zero element (the stored box is then left untouched).
        """
        # indices of all non-zero elements, one (row, column) pair per line
        nonzero_indices = torch.nonzero(self.mask)

        # no non-zero element: there is no box to compute
        if nonzero_indices.size(0) == 0:
            return []

        # smallest and largest index along each axis
        y_min, x_min = nonzero_indices.min(dim=0).values.tolist()
        y_max, x_max = nonzero_indices.max(dim=0).values.tolist()

        bbox = [x_min, y_min, x_max, y_max]
        self.x1, self.y1, self.x2, self.y2 = bbox
        return bbox

    def to_dict(self):
        """Return the object's metadata as a plain dict (the mask is omitted)."""
        return {
            "instance_id": self.instance_id,
            "class_name": self.class_name,
            "x1": self.x1,
            "y1": self.y1,
            "x2": self.x2,
            "y2": self.y2,
            "logit": self.logit
        }