11 Commits

Author SHA1 Message Date
kiennt
33303aa62f feat: Update setup for project 2025-08-14 09:27:02 +00:00
kiennt
34b17b0280 feat : Update code, new args 2025-08-14 09:26:37 +00:00
will ye
2111d9c52c Fix demos for CPU inference (#104) 2025-05-27 00:24:30 +08:00
will ye
75aaf0c3ae Change default output dir for HF demo (#105) 2025-05-27 00:24:17 +08:00
Embodied Learner
c5780dabeb feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes … (#97)
* feat:add grounded_sam2_tracking_camera_with_continuous_id.py (closes #74)

* update README
2025-05-08 11:02:33 +08:00
Sami Haidar
7fec804683 Pinned setuptools in Dockerfile (#99)
Co-authored-by: Sami Haidar Wehbe <sami@autoenhance.ai>
2025-05-08 11:02:04 +08:00
rentainhe
9412a16276 update DINO-X api to V2 2025-04-21 01:06:01 +08:00
rentainhe
d49257700a update DINO-X api usage to dds v2 2025-04-20 01:04:26 +08:00
rentainhe
3c5a4136d4 update DINO-X api usage to dds v2 2025-04-20 00:38:38 +08:00
Andrew Choi
8238557f52 Add torch2.6 support for ms_deform_attn_cuda (#94) 2025-04-18 00:38:51 +08:00
Reuben Feinman
0bc3970292 update setuptools build requirement to fix build error (#91) 2025-03-24 22:26:04 +08:00
24 changed files with 5439 additions and 168 deletions

2
.gitignore vendored
View File

@@ -145,3 +145,5 @@ dmypy.json
outputs/ outputs/
.idea/ .idea/
tmp/
data/

View File

@@ -27,7 +27,7 @@ WORKDIR /home/appuser/Grounded-SAM-2
# Install essential Python packages # Install essential Python packages
RUN python -m pip install --upgrade pip setuptools wheel numpy \ RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
opencv-python transformers supervision pycocotools addict yapf timm opencv-python transformers supervision pycocotools addict yapf timm
# Install segment_anything package in editable mode # Install segment_anything package in editable mode

View File

@@ -20,6 +20,7 @@ In this repo, we've supported the following demo with **simple implementations**
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience. Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
## Latest updates ## Latest updates
- **2025.04.20**: Update to `dds-cloudapi-sdk` API V2 version. The V1 version in the original API for `Grounding DINO 1.5` and `DINO-X` has been deprecated, please update to the latest `dds-cloudapi-sdk` by `pip install dds-cloudapi-sdk -U` to use `Grounding DINO 1.5 / 1.6` and `DINO-X` models. Please refer to [dds-cloudapi-sdk](https://github.com/deepdataspace/dds-cloudapi-sdk) and our [API docs](https://cloud.deepdataspace.com/docs) to view more details about the update.
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details. - **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
@@ -334,6 +335,16 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
``` ```
### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream)
This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream.
```bash
python grounded_sam2_tracking_camera_with_continuous_id.py
```
## Grounded SAM 2 Florence-2 Demos ## Grounded SAM 2 Florence-2 Demos
### Grounded SAM 2 Florence-2 Image Demo ### Grounded SAM 2 Florence-2 Image Demo

View File

@@ -1,9 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk.tasks.dinox import DinoxTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk.tasks.types import DetectionTarget
from dds_cloudapi_sdk import TextPrompt
import os import os
import cv2 import cv2
@@ -27,6 +25,7 @@ IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt" SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml" SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
WITH_SLICE_INFERENCE = False WITH_SLICE_INFERENCE = False
SLICE_WH = (480, 480) SLICE_WH = (480, 480)
OVERLAP_RATIO = (0.2, 0.2) OVERLAP_RATIO = (0.2, 0.2)
@@ -48,7 +47,7 @@ config = Config(token)
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task by DetectionTask class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg" # infer_image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x] classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
@@ -62,13 +61,18 @@ if WITH_SLICE_INFERENCE:
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile: with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
temp_filename = tmpfile.name temp_filename = tmpfile.name
cv2.imwrite(temp_filename, image_slice) cv2.imwrite(temp_filename, image_slice)
image_url = client.upload_file(temp_filename) infer_image_url = client.upload_file(temp_filename)
task = DinoxTask( task = V2Task(api_path="/v2/task/dinox/detection", api_body={
image_url=image_url, "model": "DINO-X-1.0",
prompts=[TextPrompt(text=TEXT_PROMPT)], "image": infer_image_url,
bbox_threshold=0.25, "prompt": {
targets=[DetectionTarget.BBox], "type":"text",
) "text":TEXT_PROMPT
},
"targets": ["bbox", "mask"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
})
client.run_task(task) client.run_task(task)
result = task.result result = task.result
# detele the tempfile # detele the tempfile
@@ -77,7 +81,7 @@ if WITH_SLICE_INFERENCE:
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_ids = [] class_ids = []
objects = result.objects objects = result["objects"]
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj.bbox)
confidences.append(obj.score) confidences.append(obj.score)
@@ -102,19 +106,26 @@ if WITH_SLICE_INFERENCE:
class_ids = detections.class_id class_ids = detections.class_id
input_boxes = detections.xyxy input_boxes = detections.xyxy
else: else:
image_url = client.upload_file(IMG_PATH) infer_image_url = client.upload_file(IMG_PATH)
task = DinoxTask( task = V2Task(
image_url=image_url, api_path="/v2/task/dinox/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
bbox_threshold=0.25, "model": "DINO-X-1.0",
targets=[DetectionTarget.BBox], "image": infer_image_url,
"prompt": {
"type":"text",
"text":TEXT_PROMPT
},
"targets": ["bbox", "mask"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result["objects"] # the list of detected objects
objects = result.objects # the list of detected objects
input_boxes = [] input_boxes = []
@@ -123,9 +134,9 @@ else:
class_ids = [] class_ids = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
cls_name = obj.category.lower().strip() cls_name = obj["category"].lower().strip()
class_names.append(cls_name) class_names.append(cls_name)
class_ids.append(class_name_to_id[cls_name]) class_ids.append(class_name_to_id[cls_name])

View File

@@ -1,10 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import cv2 import cv2
@@ -27,8 +24,9 @@ TEXT_PROMPT = "car . building ."
IMG_PATH = "notebooks/images/cars.jpg" IMG_PATH = "notebooks/images/cars.jpg"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt" SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml" SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro GROUNDING_MODEL = "GroundingDino-1.5-Pro" # GroundingDino-1.6-Pro
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
WITH_SLICE_INFERENCE = False WITH_SLICE_INFERENCE = False
SLICE_WH = (480, 480) SLICE_WH = (480, 480)
OVERLAP_RATIO = (0.2, 0.2) OVERLAP_RATIO = (0.2, 0.2)
@@ -49,8 +47,7 @@ config = Config(token)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task API
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x] classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
@@ -65,26 +62,33 @@ if WITH_SLICE_INFERENCE:
temp_filename = tmpfile.name temp_filename = tmpfile.name
cv2.imwrite(temp_filename, image_slice) cv2.imwrite(temp_filename, image_slice)
image_url = client.upload_file(temp_filename) image_url = client.upload_file(temp_filename)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": GROUNDING_MODEL,
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model "image": image_url,
bbox_threshold=BOX_THRESHOLD, # box confidence threshold "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
# detele the tempfile # delete the tempfile
os.remove(temp_filename) os.remove(temp_filename)
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_ids = [] class_ids = []
objects = result.objects objects = result["objects"]
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
cls_name = obj.category.lower().strip() cls_name = obj["category"].lower().strip()
class_ids.append(class_name_to_id[cls_name]) class_ids.append(class_name_to_id[cls_name])
# ensure input_boxes with shape (_, 4) # ensure input_boxes with shape (_, 4)
input_boxes = np.array(input_boxes).reshape(-1, 4) input_boxes = np.array(input_boxes).reshape(-1, 4)
@@ -96,7 +100,7 @@ if WITH_SLICE_INFERENCE:
callback=callback, callback=callback,
slice_wh=SLICE_WH, slice_wh=SLICE_WH,
overlap_ratio_wh=OVERLAP_RATIO, overlap_ratio_wh=OVERLAP_RATIO,
iou_threshold=0.5, iou_threshold=IOU_THRESHOLD,
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
) )
detections = slicer(cv2.imread(IMG_PATH)) detections = slicer(cv2.imread(IMG_PATH))
@@ -107,18 +111,25 @@ if WITH_SLICE_INFERENCE:
else: else:
image_url = client.upload_file(IMG_PATH) image_url = client.upload_file(IMG_PATH)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": GROUNDING_MODEL,
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model "image": image_url,
bbox_threshold=BOX_THRESHOLD, # box confidence threshold "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
@@ -127,9 +138,9 @@ else:
class_ids = [] class_ids = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
cls_name = obj.category.lower().strip() cls_name = obj["category"].lower().strip()
class_names.append(cls_name) class_names.append(cls_name)
class_ids.append(class_name_to_id[cls_name]) class_ids.append(class_name_to_id[cls_name])

View File

@@ -23,7 +23,7 @@ parser.add_argument("--text-prompt", default="car. tire.")
parser.add_argument("--img-path", default="notebooks/images/truck.jpg") parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt") parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml") parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
parser.add_argument("--output-dir", default="outputs/test_sam2.1") parser.add_argument("--output-dir", default="outputs/grounded_sam2_hf_demo")
parser.add_argument("--no-dump-json", action="store_true") parser.add_argument("--no-dump-json", action="store_true")
parser.add_argument("--force-cpu", action="store_true") parser.add_argument("--force-cpu", action="store_true")
args = parser.parse_args() args = parser.parse_args()
@@ -44,7 +44,7 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# use bfloat16 # use bfloat16
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__() torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8: if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True

View File

@@ -61,6 +61,7 @@ boxes, confidences, labels = predict(
caption=text, caption=text,
box_threshold=BOX_THRESHOLD, box_threshold=BOX_THRESHOLD,
text_threshold=TEXT_THRESHOLD, text_threshold=TEXT_THRESHOLD,
device=DEVICE
) )
# process the box prompt for SAM 2 # process the box prompt for SAM 2
@@ -70,9 +71,9 @@ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
# FIXME: figure how does this influence the G-DINO model # FIXME: figure how does this influence the G-DINO model
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8: if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True

View File

@@ -0,0 +1,536 @@
import copy
import os
import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.build_sam import build_sam2, build_sam2_video_predictor
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
from utils.common_utils import CommonUtils
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images
# Setup environment
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
class GroundingDinoPredictor:
"""
Wrapper for using a GroundingDINO model for zero-shot object detection.
"""
def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"):
"""
Initialize the GroundingDINO predictor.
Args:
model_id (str): HuggingFace model ID to load.
device (str): Device to run the model on ('cuda' or 'cpu').
"""
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
self.device = device
self.processor = AutoProcessor.from_pretrained(model_id)
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
device
)
def predict(
self,
image: "PIL.Image.Image",
text_prompts: str,
box_threshold=0.25,
text_threshold=0.25,
):
"""
Perform object detection using text prompts.
Args:
image (PIL.Image.Image): Input RGB image.
text_prompts (str): Text prompt describing target objects.
box_threshold (float): Confidence threshold for box selection.
text_threshold (float): Confidence threshold for text match.
Returns:
Tuple[Tensor, List[str]]: Bounding boxes and matched class labels.
"""
inputs = self.processor(
images=image, text=text_prompts, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
results = self.processor.post_process_grounded_object_detection(
outputs,
inputs.input_ids,
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[image.size[::-1]],
)
return results[0]["boxes"], results[0]["labels"]
class SAM2ImageSegmentor:
"""
Wrapper class for SAM2-based segmentation given bounding boxes.
"""
def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"):
"""
Initialize the SAM2 image segmentor.
Args:
sam_model_cfg (str): Path to the SAM2 config file.
sam_model_ckpt (str): Path to the SAM2 checkpoint file.
device (str): Device to load the model on ('cuda' or 'cpu').
"""
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
self.device = device
sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device)
self.predictor = SAM2ImagePredictor(sam_model)
def set_image(self, image: np.ndarray):
"""
Set the input image for segmentation.
Args:
image (np.ndarray): RGB image array with shape (H, W, 3).
"""
self.predictor.set_image(image)
def predict_masks_from_boxes(self, boxes: torch.Tensor):
"""
Predict segmentation masks from given bounding boxes.
Args:
boxes (torch.Tensor): Bounding boxes as (N, 4) tensor.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]:
- masks: Binary masks per box, shape (N, H, W)
- scores: Confidence scores for each mask
- logits: Raw logits from the model
"""
masks, scores, logits = self.predictor.predict(
point_coords=None,
point_labels=None,
box=boxes,
multimask_output=False,
)
# Normalize shape to (N, H, W)
if masks.ndim == 2:
masks = masks[None]
scores = scores[None]
logits = logits[None]
elif masks.ndim == 4:
masks = masks.squeeze(1)
return masks, scores, logits
class IncrementalObjectTracker:
def __init__(
self,
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text="car.",
detection_interval=20,
):
"""
Initialize an incremental object tracker using GroundingDINO and SAM2.
Args:
grounding_model_id (str): HuggingFace model ID for GroundingDINO.
sam2_model_cfg (str): Path to SAM2 model config file.
sam2_ckpt_path (str): Path to SAM2 model checkpoint.
device (str): Device to run the models on ('cuda' or 'cpu').
prompt_text (str): Initial text prompt for detection.
detection_interval (int): Frame interval between full detections.
"""
self.device = device
self.detection_interval = detection_interval
self.prompt_text = prompt_text
# Load models
self.grounding_predictor = GroundingDinoPredictor(
model_id=grounding_model_id, device=device
)
self.sam2_segmentor = SAM2ImageSegmentor(
sam_model_cfg=sam2_model_cfg,
sam_model_ckpt=sam2_ckpt_path,
device=device,
)
self.video_predictor = build_sam2_video_predictor(
sam2_model_cfg, sam2_ckpt_path
)
# Initialize inference state
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device)
self.total_frames = 0
self.objects_count = 0
self.frame_cache_limit = detection_interval - 1 # or higher depending on memory
# Store tracking results
self.last_mask_dict = MaskDictionaryModel()
self.track_dict = MaskDictionaryModel()
def add_image(self, image_np: np.ndarray):
"""
Add a new image frame to the tracker and perform detection or tracking update.
Args:
image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8.
Returns:
np.ndarray: Annotated image with object masks and labels.
"""
import numpy as np
from PIL import Image
img_pil = Image.fromarray(image_np)
# Step 1: Perform detection every detection_interval frames
if self.total_frames % self.detection_interval == 0:
if (
self.inference_state["video_height"] is None
or self.inference_state["video_width"] is None
):
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
if self.inference_state["images"].shape[0] > self.frame_cache_limit:
print(
f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory."
)
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
(
self.inference_state["video_height"],
self.inference_state["video_width"],
) = image_np.shape[:2]
# 1.1 GroundingDINO object detection
boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text)
if boxes.shape[0] == 0:
return
# 1.2 SAM2 segmentation from detection boxes
self.sam2_segmentor.set_image(image_np)
masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes)
# 1.3 Build MaskDictionaryModel
mask_dict = MaskDictionaryModel(
promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy"
)
mask_dict.add_new_frame_annotation(
mask_list=torch.tensor(masks).to(self.device),
box_list=torch.tensor(boxes),
label_list=labels,
)
# 1.4 Object ID tracking and IOU-based update
self.objects_count = mask_dict.update_masks(
tracking_annotation_dict=self.last_mask_dict,
iou_threshold=0.3,
objects_count=self.objects_count,
)
# 1.5 Reset video tracker state
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
self.video_predictor.reset_state(self.inference_state)
for object_id, object_info in mask_dict.labels.items():
frame_idx, _, _ = self.video_predictor.add_new_mask(
self.inference_state,
frame_idx,
object_id,
object_info.mask,
)
self.track_dict = copy.deepcopy(mask_dict)
self.last_mask_dict = mask_dict
else:
# Step 2: Use incremental tracking for intermediate frames
frame_idx = self.video_predictor.add_new_frame(
self.inference_state, image_np
)
# Step 3: Tracking propagation using the video predictor
frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame(
inference_state=self.inference_state,
frame_idx=frame_idx,
)
# Step 4: Update the mask dictionary based on tracked masks
frame_masks = MaskDictionaryModel()
for i, obj_id in enumerate(obj_ids):
out_mask = video_res_masks[i] > 0.0
object_info = ObjectInfo(
instance_id=obj_id,
mask=out_mask[0],
class_name=self.track_dict.get_target_class_name(obj_id),
logit=self.track_dict.get_target_logit(obj_id),
)
object_info.update_box()
frame_masks.labels[obj_id] = object_info
frame_masks.mask_name = f"mask_{frame_idx:05d}.npy"
frame_masks.mask_height = out_mask.shape[-2]
frame_masks.mask_width = out_mask.shape[-1]
self.last_mask_dict = copy.deepcopy(frame_masks)
# Step 5: Build mask array
H, W = image_np.shape[:2]
mask_img = torch.zeros((H, W), dtype=torch.int32)
for obj_id, obj_info in self.last_mask_dict.labels.items():
mask_img[obj_info.mask == True] = obj_id
mask_array = mask_img.cpu().numpy()
# Step 6: Visualization
annotated_frame = self.visualize_frame_with_mask_and_metadata(
image_np=image_np,
mask_array=mask_array,
json_metadata=self.last_mask_dict.to_dict(),
)
print(f"[Tracker] Total processed frames: {self.total_frames}")
self.total_frames += 1
torch.cuda.empty_cache()
return annotated_frame
def set_prompt(self, new_prompt: str):
"""
Dynamically update the GroundingDINO prompt and reset tracking state
to force a new object detection.
"""
self.prompt_text = new_prompt
self.total_frames = 0 # Trigger immediate re-detection
self.inference_state = self.video_predictor.init_state()
self.inference_state["images"] = torch.empty(
(0, 3, 1024, 1024), device=self.device
)
self.inference_state["video_height"] = None
self.inference_state["video_width"] = None
print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.")
def save_current_state(self, output_dir, raw_image: np.ndarray = None):
"""
Save the current mask, metadata, raw image, and annotated result.
Args:
output_dir (str): The root output directory.
raw_image (np.ndarray, optional): The original input image (RGB).
"""
mask_data_dir = os.path.join(output_dir, "mask_data")
json_data_dir = os.path.join(output_dir, "json_data")
image_data_dir = os.path.join(output_dir, "images")
vis_data_dir = os.path.join(output_dir, "result")
os.makedirs(mask_data_dir, exist_ok=True)
os.makedirs(json_data_dir, exist_ok=True)
os.makedirs(image_data_dir, exist_ok=True)
os.makedirs(vis_data_dir, exist_ok=True)
frame_masks = self.last_mask_dict
# Ensure mask_name is valid
if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"):
frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy"
base_name = f"image_{self.total_frames:05d}"
# Save segmentation mask
mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width)
for obj_id, obj_info in frame_masks.labels.items():
mask_img[obj_info.mask == True] = obj_id
np.save(
os.path.join(mask_data_dir, frame_masks.mask_name),
mask_img.numpy().astype(np.uint16),
)
# Save metadata as JSON
json_path = os.path.join(json_data_dir, base_name + ".json")
frame_masks.to_json(json_path)
# Save raw input image
if raw_image is not None:
image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr)
# Save annotated image with mask, bounding boxes, and labels
annotated_image = self.visualize_frame_with_mask_and_metadata(
image_np=raw_image,
mask_array=mask_img.numpy().astype(np.uint16),
json_metadata=frame_masks.to_dict(),
)
annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
cv2.imwrite(
os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr
)
print(
f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully."
)
def visualize_frame_with_mask_and_metadata(
self,
image_np: np.ndarray,
mask_array: np.ndarray,
json_metadata: dict,
):
image = image_np.copy()
H, W = image.shape[:2]
# Step 1: Parse metadata and build object entries
metadata_lookup = json_metadata.get("labels", {})
all_object_ids = []
all_object_boxes = []
all_object_classes = []
all_object_masks = []
for obj_id_str, obj_info in metadata_lookup.items():
instance_id = obj_info.get("instance_id")
if instance_id is None or instance_id == 0:
continue
if instance_id not in np.unique(mask_array):
continue
object_mask = mask_array == instance_id
all_object_ids.append(instance_id)
x1 = obj_info.get("x1", 0)
y1 = obj_info.get("y1", 0)
x2 = obj_info.get("x2", 0)
y2 = obj_info.get("y2", 0)
all_object_boxes.append([x1, y1, x2, y2])
all_object_classes.append(obj_info.get("class_name", "unknown"))
all_object_masks.append(object_mask[None]) # Shape (1, H, W)
# Step 2: Check if valid objects exist
if len(all_object_ids) == 0:
print("No valid object instances found in metadata.")
return image
# Step 3: Sort by instance ID
paired = list(
zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes)
)
paired.sort(key=lambda x: x[0])
all_object_ids = [p[0] for p in paired]
all_object_boxes = [p[1] for p in paired]
all_object_masks = [p[2] for p in paired]
all_object_classes = [p[3] for p in paired]
# Step 4: Build detections
all_object_masks = np.concatenate(all_object_masks, axis=0)
detections = sv.Detections(
xyxy=np.array(all_object_boxes),
mask=all_object_masks,
class_id=np.array(all_object_ids, dtype=np.int32),
)
labels = [
f"{instance_id}: {class_name}"
for instance_id, class_name in zip(all_object_ids, all_object_classes)
]
# Step 5: Annotate image
annotated_frame = image.copy()
mask_annotator = sv.MaskAnnotator()
box_annotator = sv.BoxAnnotator()
label_annotator = sv.LabelAnnotator()
annotated_frame = mask_annotator.annotate(annotated_frame, detections)
annotated_frame = box_annotator.annotate(annotated_frame, detections)
annotated_frame = label_annotator.annotate(annotated_frame, detections, labels)
return annotated_frame
import os
import cv2
import torch
from utils.common_utils import CommonUtils
def main():
# Parameter settings
output_dir = "./outputs"
prompt_text = "hand."
detection_interval = 20
max_frames = 300 # Maximum number of frames to process (prevents infinite loop)
os.makedirs(output_dir, exist_ok=True)
# Initialize the object tracker
tracker = IncrementalObjectTracker(
grounding_model_id="IDEA-Research/grounding-dino-tiny",
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
device="cuda",
prompt_text=prompt_text,
detection_interval=detection_interval,
)
tracker.set_prompt("person.")
# Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4"))
cap = cv2.VideoCapture(0)
if not cap.isOpened():
print("[Error] Cannot open camera.")
return
print("[Info] Camera opened. Press 'q' to quit.")
frame_idx = 0
try:
while True:
ret, frame = cap.read()
if not ret:
print("[Warning] Failed to capture frame.")
break
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
print(f"[Frame {frame_idx}] Processing live frame...")
process_image = tracker.add_image(frame_rgb)
if process_image is None or not isinstance(process_image, np.ndarray):
print(f"[Warning] Skipped frame {frame_idx} due to empty result.")
frame_idx += 1
continue
# process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR)
# cv2.imshow("Live Inference", process_image_bgr)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# print("[Info] Quit signal received.")
# break
tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb)
frame_idx += 1
if frame_idx >= max_frames:
print(f"[Info] Reached max_frames {max_frames}. Stopping.")
break
except KeyboardInterrupt:
print("[Info] Interrupted by user (Ctrl+C).")
finally:
cap.release()
cv2.destroyAllWindows()
print("[Done] Live inference complete.")
if __name__ == "__main__":
main()

View File

@@ -1,9 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for DINO-X - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk.tasks.dinox import DinoxTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk.tasks.types import DetectionTarget
from dds_cloudapi_sdk import TextPrompt
import os import os
import cv2 import cv2
@@ -30,6 +28,7 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_DINOX = "Your API token" API_TOKEN_FOR_DINOX = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
""" """
Step 1: Environment settings and model initialization for SAM 2 Step 1: Environment settings and model initialization for SAM 2
@@ -98,22 +97,29 @@ config = Config(API_TOKEN_FOR_DINOX)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DinoxTask( task = V2Task(
image_url=image_url, api_path="/v2/task/dinox/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
bbox_threshold=0.25, "model": "DINO-X-1.0",
targets=[DetectionTarget.BBox], "image": image_url,
"prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
@@ -121,9 +127,9 @@ confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)

View File

@@ -1,10 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - Update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import cv2 import cv2
@@ -31,6 +28,7 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
API_TOKEN_FOR_GD1_5 = "Your API token" API_TOKEN_FOR_GD1_5 = "Your API token"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"] PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
BOX_THRESHOLD = 0.2 BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
""" """
Step 1: Environment settings and model initialization for SAM 2 Step 1: Environment settings and model initialization for SAM 2
@@ -99,33 +97,38 @@ config = Config(API_TOKEN_FOR_GD1_5)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=TEXT_PROMPT)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": "GroundingDino-1.5-Pro",
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model "image": image_url,
bbox_threshold=BOX_THRESHOLD, "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)

View File

@@ -1,11 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import torch import torch
@@ -51,6 +47,9 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).
# setup the input image and text prompt for SAM 2 and Grounding DINO # setup the input image and text prompt for SAM 2 and Grounding DINO
# VERY important: text queries need to be lowercased + end with a dot # VERY important: text queries need to be lowercased + end with a dot
text = "car." text = "car."
BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
GROUNDING_MODEL = "GroundingDino-1.6-Pro" # 使用字符串替代枚举值
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg` # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "notebooks/videos/car" video_dir = "notebooks/videos/car"
@@ -102,24 +101,32 @@ for start_frame_idx in range(0, len(frame_names), step):
client = Client(config) client = Client(config)
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text=text)], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": GROUNDING_MODEL,
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model "image": image_url,
"prompt": {
"type": "text",
"text": text
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
confidences = [] confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)
OBJECTS = class_names OBJECTS = class_names
@@ -154,7 +161,7 @@ for start_frame_idx in range(0, len(frame_names), step):
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count) objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=IOU_THRESHOLD, objects_count=objects_count)
print("objects_count", objects_count) print("objects_count", objects_count)
else: else:

View File

@@ -1,10 +1,7 @@
# dds cloudapi for Grounding DINO 1.5 # dds cloudapi for Grounding DINO 1.5 - update to V2Task API
from dds_cloudapi_sdk import Config from dds_cloudapi_sdk import Config
from dds_cloudapi_sdk import Client from dds_cloudapi_sdk import Client
from dds_cloudapi_sdk import DetectionTask from dds_cloudapi_sdk.tasks.v2_task import V2Task
from dds_cloudapi_sdk import TextPrompt
from dds_cloudapi_sdk import DetectionModel
from dds_cloudapi_sdk import DetectionTarget
import os import os
import cv2 import cv2
@@ -54,6 +51,11 @@ inference_state = video_predictor.init_state(video_path=video_dir)
ann_frame_idx = 0 # the frame index we interact with ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers) ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# 添加参数设置
TEXT_PROMPT = "children. pillow"
BOX_THRESHOLD = 0.2
IOU_THRESHOLD = 0.8
""" """
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
@@ -70,23 +72,29 @@ config = Config(token)
# Step 2: initialize the client # Step 2: initialize the client
client = Client(config) client = Client(config)
# Step 3: run the task by DetectionTask class # Step 3: run the task using V2Task class
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
# if you are processing local image file, upload them to DDS server to get the image url # if you are processing local image file, upload them to DDS server to get the image url
image_url = client.upload_file(img_path) image_url = client.upload_file(img_path)
task = DetectionTask( task = V2Task(
image_url=image_url, api_path="/v2/task/grounding_dino/detection",
prompts=[TextPrompt(text="children. pillow")], api_body={
targets=[DetectionTarget.BBox], # detect bbox "model": "GroundingDino-1.5-Pro",
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model "image": image_url,
bbox_threshold=0.2, "prompt": {
"type": "text",
"text": TEXT_PROMPT
},
"targets": ["bbox"],
"bbox_threshold": BOX_THRESHOLD,
"iou_threshold": IOU_THRESHOLD,
}
) )
client.run_task(task) client.run_task(task)
result = task.result result = task.result
objects = result.objects # the list of detected objects objects = result["objects"] # the list of detected objects
input_boxes = [] input_boxes = []
@@ -94,9 +102,9 @@ confidences = []
class_names = [] class_names = []
for idx, obj in enumerate(objects): for idx, obj in enumerate(objects):
input_boxes.append(obj.bbox) input_boxes.append(obj["bbox"])
confidences.append(obj.score) confidences.append(obj["score"])
class_names.append(obj.category) class_names.append(obj["category"])
input_boxes = np.array(input_boxes) input_boxes = np.array(input_boxes)

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.layers import DropPath, to_2tuple, trunc_normal_
from grounding_dino.groundingdino.util.misc import NestedTensor from grounding_dino.groundingdino.util.misc import NestedTensor
@@ -113,7 +113,7 @@ class WindowAttention(nn.Module):
# get pair-wise relative position index for each token inside the window # get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0]) coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1]) coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

View File

@@ -15,6 +15,19 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/version.h>
// Check PyTorch version and define appropriate macros
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
// PyTorch 2.x and above
#define GET_TENSOR_TYPE(x) x.scalar_type()
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
#else
// PyTorch 1.x
#define GET_TENSOR_TYPE(x) x.type()
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
#endif
namespace groundingdino { namespace groundingdino {
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
const int batch = value.size(0); const int batch = value.size(0);
const int spatial_size = value.size(1); const int spatial_size = value.size(1);
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n) for (int n = 0; n < batch/im2col_step_; ++n)
{ {
auto columns = output_n.select(0, n); auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size, value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(), spatial_shapes.data<int64_t>(),
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
const int batch = value.size(0); const int batch = value.size(0);
const int spatial_size = value.size(1); const int spatial_size = value.size(1);
@@ -132,7 +145,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
for (int n = 0; n < batch/im2col_step_; ++n) for (int n = 0; n < batch/im2col_step_; ++n)
{ {
auto grad_output_g = grad_output_n.select(0, n); auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(), grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size, value.data<scalar_t>() + n * im2col_step_ * per_value_size,

View File

@@ -8,7 +8,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.models.layers import DropPath from timm.layers import DropPath
class FeatureResizer(nn.Module): class FeatureResizer(nn.Module):

View File

@@ -470,6 +470,7 @@ class TransformerEncoder(nn.Module):
ref_y, ref_x = torch.meshgrid( ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
indexing="ij"
) )
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
@@ -859,7 +860,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
return tensor if pos is None else tensor + pos return tensor if pos is None else tensor + pos
def forward_ffn(self, tgt): def forward_ffn(self, tgt):
with torch.cuda.amp.autocast(enabled=False): with torch.amp.autocast("cuda", enabled=False):
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout4(tgt2) tgt = tgt + self.dropout4(tgt2)
tgt = self.norm3(tgt) tgt = self.norm3(tgt)

View File

@@ -79,6 +79,7 @@ def gen_encoder_output_proposals(
grid_y, grid_x = torch.meshgrid( grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device), torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
indexing="ij"
) )
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2 grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2

View File

@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
y = torch.arange(0, h, dtype=torch.float) y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float) x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x) y, x = torch.meshgrid(y, x, indexing="ij")
x_mask = masks * x.unsqueeze(0) x_mask = masks * x.unsqueeze(0)
x_max = x_mask.flatten(1).max(-1)[0] x_max = x_mask.flatten(1).max(-1)[0]

View File

@@ -63,6 +63,7 @@ def predict(
model = model.to(device) model = model.to(device)
image = image.to(device) image = image.to(device)
model.eval()
with torch.no_grad(): with torch.no_grad():
outputs = model(image[None], captions=[caption]) outputs = model(image[None], captions=[caption])

View File

@@ -1,6 +1,68 @@
[build-system] [build-system]
requires = [ requires = ["setuptools>=61.0", "wheel"]
"setuptools>=61.0",
"torch>=2.3.1",
]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[project]
name = "Grounded-SAM-2"
version = "1.0"
description = "Grounded SAM 2: Ground and Track Anything in Videos"
readme = "README.md"
requires-python = ">=3.10.0"
license = { text = "Apache 2.0" }
authors = [{ name = "Meta AI", email = "segment-anything@meta.com" }]
keywords = ["segmentation", "computer vision", "deep learning"]
dependencies = [
"torch>=2.3.1",
"torchvision>=0.18.1",
"numpy>=1.24.4",
"tqdm>=4.66.1",
"hydra-core>=1.3.2",
"iopath>=0.1.10",
"pillow>=9.4.0",
"opencv-python-headless>=4.11.0.86",
"supervision>=0.26.1",
"pycocotools>=2.0.10",
"transformers>=4.55.1",
"addict>=2.4.0",
"yapf>=0.43.0",
"timm>=1.0.19",
"pdf2image>=1.17.0",
]
[project.optional-dependencies]
notebooks = [
"matplotlib>=3.9.1",
"jupyter>=1.0.0",
"opencv-python>=4.7.0",
"eva-decord>=0.6.1",
]
interactive-demo = [
"Flask>=3.0.3",
"Flask-Cors>=5.0.0",
"av>=13.0.0",
"dataclasses-json>=0.6.7",
"eva-decord>=0.6.1",
"gunicorn>=23.0.0",
"imagesize>=1.4.1",
"pycocotools>=2.0.8",
"strawberry-graphql>=0.243.0",
]
dev = [
"black==24.2.0",
"usort==1.0.2",
"ufmt==2.0.0b2",
"fvcore>=0.1.5.post20221221",
"pandas>=2.2.2",
"scikit-image>=0.24.0",
"tensorboard>=2.17.0",
"pycocotools>=2.0.8",
"tensordict>=0.5.0",
"opencv-python>=4.7.0",
"submitit>=1.5.1",
]
[tool.setuptools]
# extensions = [{ name = "sam2._C", sources = ["sam2/csrc/connected_components.cu"] }]
packages = ["sam2", "grounding_dino"]

View File

@@ -12,7 +12,7 @@ import torch
from tqdm import tqdm from tqdm import tqdm
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames ,process_stream_frame
class SAM2VideoPredictor(SAM2Base): class SAM2VideoPredictor(SAM2Base):
@@ -43,13 +43,16 @@ class SAM2VideoPredictor(SAM2Base):
@torch.inference_mode() @torch.inference_mode()
def init_state( def init_state(
self, self,
video_path, video_path=None,
offload_video_to_cpu=False, offload_video_to_cpu=False,
offload_state_to_cpu=False, offload_state_to_cpu=False,
async_loading_frames=False, async_loading_frames=False,
): ):
"""Initialize an inference state.""" """Initialize an inference state."""
compute_device = self.device # device of the model compute_device = self.device # device of the model
inference_state = {}
if video_path is not None:
# Preload video frames from file
images, video_height, video_width = load_video_frames( images, video_height, video_width = load_video_frames(
video_path=video_path, video_path=video_path,
image_size=self.image_size, image_size=self.image_size,
@@ -57,9 +60,16 @@ class SAM2VideoPredictor(SAM2Base):
async_loading_frames=async_loading_frames, async_loading_frames=async_loading_frames,
compute_device=compute_device, compute_device=compute_device,
) )
inference_state = {}
inference_state["images"] = images inference_state["images"] = images
inference_state["num_frames"] = len(images) inference_state["num_frames"] = len(images)
else:
# Real-time streaming mode
print("Real-time streaming mode: waiting for first image input...")
images = None
video_height, video_width = None, None
inference_state["images"] = None
inference_state["num_frames"] = 0
# whether to offload the video frames to CPU memory # whether to offload the video frames to CPU memory
# turning on this option saves the GPU memory with only a very small overhead # turning on this option saves the GPU memory with only a very small overhead
inference_state["offload_video_to_cpu"] = offload_video_to_cpu inference_state["offload_video_to_cpu"] = offload_video_to_cpu
@@ -107,7 +117,9 @@ class SAM2VideoPredictor(SAM2Base):
inference_state["tracking_has_started"] = False inference_state["tracking_has_started"] = False
inference_state["frames_already_tracked"] = {} inference_state["frames_already_tracked"] = {}
# Warm up the visual backbone and cache the image feature on frame 0 # Warm up the visual backbone and cache the image feature on frame 0
if video_path is not None:
self._get_image_feature(inference_state, frame_idx=0, batch_size=1) self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
return inference_state return inference_state
@classmethod @classmethod
@@ -743,6 +755,133 @@ class SAM2VideoPredictor(SAM2Base):
inference_state, pred_masks inference_state, pred_masks
) )
yield frame_idx, obj_ids, video_res_masks yield frame_idx, obj_ids, video_res_masks
@torch.inference_mode()
def add_new_frame(self, inference_state, new_image):
"""
Add a new frame to the inference state and cache its image features.
Args:
inference_state (dict): The current inference state containing cached frames, features, and tracking information.
new_image (Tensor or ndarray): The input image frame (in HWC or CHW format depending on upstream processing).
Returns:
frame_idx (int): The index of the newly added frame within the inference state.
"""
device = inference_state["device"]
# Preprocess the input frame and convert it to a normalized tensor
img_tensor, orig_h, orig_w = process_stream_frame(
img_array=new_image,
image_size=self.image_size,
offload_to_cpu=False,
compute_device=device,
)
# Handle initialization of the image sequence if this is the first frame
images = inference_state.get("images", None)
if images is None or (isinstance(images, list) and len(images) == 0):
# First frame: initialize image tensor batch
inference_state["images"] = img_tensor.unsqueeze(0) # Shape: [1, C, H, W]
else:
# Append to existing tensor batch
if isinstance(images, list):
raise ValueError(
"inference_state['images'] should be a Tensor, not a list after initialization."
)
img_tensor = img_tensor.to(images.device)
inference_state["images"] = torch.cat(
[images, img_tensor.unsqueeze(0)], dim=0
)
# Update frame count and compute new frame index
inference_state["num_frames"] = inference_state["images"].shape[0]
frame_idx = inference_state["num_frames"] - 1
# Cache visual features for the newly added frame
image_batch = img_tensor.float().unsqueeze(0) # Shape: [1, C, H, W]
backbone_out = self.forward_image(image_batch)
inference_state["cached_features"][frame_idx] = (image_batch, backbone_out)
return frame_idx
@torch.inference_mode()
def infer_single_frame(self, inference_state, frame_idx):
"""
Run inference on a single frame using existing points/masks in the inference state.
Args:
inference_state (dict): The current state of the tracking process.
frame_idx (int): Index of the frame to run inference on.
Returns:
frame_idx (int): Same as input; the index of the processed frame.
obj_ids (list): List of currently tracked object IDs.
video_res_masks (Tensor): Segmentation masks predicted for the objects in the frame.
"""
if frame_idx >= inference_state["num_frames"]:
raise ValueError(
f"Frame index {frame_idx} out of range (num_frames={inference_state['num_frames']})."
)
self.propagate_in_video_preflight(inference_state)
output_dict = inference_state["output_dict"]
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
batch_size = self._get_obj_num(inference_state)
# Ensure that initial conditioning points exist
if len(output_dict["cond_frame_outputs"]) == 0:
raise RuntimeError(
"No conditioning points provided. Please add points before inference."
)
# Decide whether to clear nearby memory based on number of objects
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
)
obj_ids = inference_state["obj_ids"]
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
# If output is already consolidated with conditioning inputs
storage_key = "cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
if clear_non_cond_mem:
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
# If output was inferred without conditioning
storage_key = "non_cond_frame_outputs"
current_out = output_dict[storage_key][frame_idx]
pred_masks = current_out["pred_masks"]
else:
# Run model inference for this frame
storage_key = "non_cond_frame_outputs"
current_out, pred_masks = self._run_single_frame_inference(
inference_state=inference_state,
output_dict=output_dict,
frame_idx=frame_idx,
batch_size=batch_size,
is_init_cond_frame=False,
point_inputs=None,
mask_inputs=None,
reverse=False,
run_mem_encoder=True,
)
output_dict[storage_key][frame_idx] = current_out
# Organize per-object outputs and mark frame as tracked
self._add_output_per_object(
inference_state, frame_idx, current_out, storage_key
)
inference_state["frames_already_tracked"][frame_idx] = {"reverse": False}
# Convert output to original video resolution
_, video_res_masks = self._get_orig_video_res_output(
inference_state, pred_masks
)
return frame_idx, obj_ids, video_res_masks
def _add_output_per_object( def _add_output_per_object(
self, inference_state, frame_idx, current_out, storage_key self, inference_state, frame_idx, current_out, storage_key

View File

@@ -8,6 +8,7 @@ import os
import warnings import warnings
from threading import Thread from threading import Thread
from typing import Tuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
@@ -209,6 +210,74 @@ def load_video_frames(
"Only MP4 video and JPEG folder are supported at this moment" "Only MP4 video and JPEG folder are supported at this moment"
) )
def process_stream_frame(
img_array: np.ndarray,
image_size: int,
img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
offload_to_cpu: bool = False,
compute_device: torch.device = torch.device("cuda"),
):
"""
Convert a raw image array (H,W,3 or 3,H,W) into a modelready tensor.
Steps
-----
1. Resize the shorter side to `image_size`, keeping aspect ratio,
then centercrop/pad to `image_size` × `image_size`.
2. Change layout to [3, H, W] and cast to float32 in [0,1].
3. Normalise with ImageNet statistics.
4. Optionally move to `compute_device`.
Returns
-------
img_tensor : torch.FloatTensor # shape [3, image_size, image_size]
orig_h : int
orig_w : int
"""
# ↪ uses your existing helper so behaviour matches the batch loader
img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
# Normalisation (done *after* potential device move for efficiency)
img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
if not offload_to_cpu:
img_tensor = img_tensor.to(compute_device)
img_mean_t = img_mean_t.to(compute_device)
img_std_t = img_std_t.to(compute_device)
img_tensor.sub_(img_mean_t).div_(img_std_t)
return img_tensor, orig_h, orig_w
def _resize_and_convert_to_tensor(img_array, image_size):
"""
Resize the input image array and convert it into a tensor.
Also return original image height and width.
"""
# Convert numpy array to PIL image and ensure RGB
img_pil = Image.fromarray(img_array).convert("RGB")
# Save original size (PIL: size = (width, height))
video_width, video_height = img_pil.size
# Resize with high-quality LANCZOS filter
img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
# Convert resized image back to numpy and then to float tensor
img_resized_array = np.array(img_resized)
if img_resized_array.dtype == np.uint8:
img_resized_array = img_resized_array / 255.0
else:
raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
# Convert to PyTorch tensor and permute to [C, H, W]
img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
return img_tensor, video_height, video_width
def load_video_frames_from_jpg_images( def load_video_frames_from_jpg_images(
video_path, video_path,

View File

@@ -623,7 +623,7 @@ class Trainer:
# compute output # compute output
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast( with torch.amp.autocast("cuda",
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False), enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
dtype=( dtype=(
get_amp_type(self.optim_conf.amp.amp_dtype) get_amp_type(self.optim_conf.amp.amp_dtype)
@@ -858,7 +858,8 @@ class Trainer:
# grads will also update a model even if the step doesn't produce # grads will also update a model even if the step doesn't produce
# gradients # gradients
self.optim.zero_grad(set_to_none=True) self.optim.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast( with torch.amp.autocast(
"cuda",
enabled=self.optim_conf.amp.enabled, enabled=self.optim_conf.amp.enabled,
dtype=get_amp_type(self.optim_conf.amp.amp_dtype), dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
): ):

4388
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff