Compare commits
11 Commits
update_sam
...
main
Author | SHA1 | Date | |
---|---|---|---|
![]() |
33303aa62f | ||
![]() |
34b17b0280 | ||
![]() |
2111d9c52c | ||
![]() |
75aaf0c3ae | ||
![]() |
c5780dabeb | ||
![]() |
7fec804683 | ||
![]() |
9412a16276 | ||
![]() |
d49257700a | ||
![]() |
3c5a4136d4 | ||
![]() |
8238557f52 | ||
![]() |
0bc3970292 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -145,3 +145,5 @@ dmypy.json
|
|||||||
outputs/
|
outputs/
|
||||||
|
|
||||||
.idea/
|
.idea/
|
||||||
|
tmp/
|
||||||
|
data/
|
||||||
|
@@ -27,7 +27,7 @@ WORKDIR /home/appuser/Grounded-SAM-2
|
|||||||
|
|
||||||
|
|
||||||
# Install essential Python packages
|
# Install essential Python packages
|
||||||
RUN python -m pip install --upgrade pip setuptools wheel numpy \
|
RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
|
||||||
opencv-python transformers supervision pycocotools addict yapf timm
|
opencv-python transformers supervision pycocotools addict yapf timm
|
||||||
|
|
||||||
# Install segment_anything package in editable mode
|
# Install segment_anything package in editable mode
|
||||||
|
11
README.md
11
README.md
@@ -20,6 +20,7 @@ In this repo, we've supported the following demo with **simple implementations**
|
|||||||
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
|
Grounded SAM 2 does not introduce significant methodological changes compared to [Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks](https://arxiv.org/abs/2401.14159). Both approaches leverage the capabilities of open-world models to address complex visual tasks. Consequently, we try to **simplify the code implementation** in this repository, aiming to enhance user convenience.
|
||||||
|
|
||||||
## Latest updates
|
## Latest updates
|
||||||
|
- **2025.04.20**: Update to `dds-cloudapi-sdk` API V2 version. The V1 version in the original API for `Grounding DINO 1.5` and `DINO-X` has been deprecated, please update to the latest `dds-cloudapi-sdk` by `pip install dds-cloudapi-sdk -U` to use `Grounding DINO 1.5 / 1.6` and `DINO-X` models. Please refer to [dds-cloudapi-sdk](https://github.com/deepdataspace/dds-cloudapi-sdk) and our [API docs](https://cloud.deepdataspace.com/docs) to view more details about the update.
|
||||||
|
|
||||||
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
|
- **2024.12.02**: Support **DINO-X with SAM 2** demos (including object segmentation and tracking), please install the latest version of `dds-cloudapi-sdk==0.3.3` and refer to [Grounded SAM 2 (with DINO-X)](#grounded-sam-2-image-demo-with-dino-x) and [Grounded SAM 2 Video (with DINO-X)](#grounded-sam-2-video-object-tracking-demo-with-custom-video-input-with-dino-x) for more details.
|
||||||
|
|
||||||
@@ -334,6 +335,16 @@ python grounded_sam2_tracking_demo_with_continuous_id_plus.py
|
|||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Grounded-SAM-2 Real-Time Object Tracking with Continuous ID (Live Video / Camera Stream)
|
||||||
|
|
||||||
|
This method enables **real-time object tracking** with **ID continuity** from a live camera or video stream.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python grounded_sam2_tracking_camera_with_continuous_id.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Grounded SAM 2 Florence-2 Demos
|
## Grounded SAM 2 Florence-2 Demos
|
||||||
### Grounded SAM 2 Florence-2 Image Demo
|
### Grounded SAM 2 Florence-2 Image Demo
|
||||||
|
|
||||||
|
@@ -1,9 +1,7 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for Grounding DINO 1.5
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||||
from dds_cloudapi_sdk.tasks.types import DetectionTarget
|
|
||||||
from dds_cloudapi_sdk import TextPrompt
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -27,6 +25,7 @@ IMG_PATH = "notebooks/images/cars.jpg"
|
|||||||
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
||||||
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
|
IOU_THRESHOLD = 0.8
|
||||||
WITH_SLICE_INFERENCE = False
|
WITH_SLICE_INFERENCE = False
|
||||||
SLICE_WH = (480, 480)
|
SLICE_WH = (480, 480)
|
||||||
OVERLAP_RATIO = (0.2, 0.2)
|
OVERLAP_RATIO = (0.2, 0.2)
|
||||||
@@ -48,7 +47,7 @@ config = Config(token)
|
|||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task by DetectionTask class
|
||||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
# infer_image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
|
|
||||||
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
||||||
@@ -62,13 +61,18 @@ if WITH_SLICE_INFERENCE:
|
|||||||
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
|
||||||
temp_filename = tmpfile.name
|
temp_filename = tmpfile.name
|
||||||
cv2.imwrite(temp_filename, image_slice)
|
cv2.imwrite(temp_filename, image_slice)
|
||||||
image_url = client.upload_file(temp_filename)
|
infer_image_url = client.upload_file(temp_filename)
|
||||||
task = DinoxTask(
|
task = V2Task(api_path="/v2/task/dinox/detection", api_body={
|
||||||
image_url=image_url,
|
"model": "DINO-X-1.0",
|
||||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
"image": infer_image_url,
|
||||||
bbox_threshold=0.25,
|
"prompt": {
|
||||||
targets=[DetectionTarget.BBox],
|
"type":"text",
|
||||||
)
|
"text":TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox", "mask"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
})
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
# detele the tempfile
|
# detele the tempfile
|
||||||
@@ -77,7 +81,7 @@ if WITH_SLICE_INFERENCE:
|
|||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_ids = []
|
class_ids = []
|
||||||
objects = result.objects
|
objects = result["objects"]
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj.bbox)
|
||||||
confidences.append(obj.score)
|
confidences.append(obj.score)
|
||||||
@@ -102,19 +106,26 @@ if WITH_SLICE_INFERENCE:
|
|||||||
class_ids = detections.class_id
|
class_ids = detections.class_id
|
||||||
input_boxes = detections.xyxy
|
input_boxes = detections.xyxy
|
||||||
else:
|
else:
|
||||||
image_url = client.upload_file(IMG_PATH)
|
infer_image_url = client.upload_file(IMG_PATH)
|
||||||
|
|
||||||
task = DinoxTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/dinox/detection",
|
||||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
api_body={
|
||||||
bbox_threshold=0.25,
|
"model": "DINO-X-1.0",
|
||||||
targets=[DetectionTarget.BBox],
|
"image": infer_image_url,
|
||||||
|
"prompt": {
|
||||||
|
"type":"text",
|
||||||
|
"text":TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox", "mask"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
objects = result["objects"] # the list of detected objects
|
||||||
objects = result.objects # the list of detected objects
|
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -123,9 +134,9 @@ else:
|
|||||||
class_ids = []
|
class_ids = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
cls_name = obj.category.lower().strip()
|
cls_name = obj["category"].lower().strip()
|
||||||
class_names.append(cls_name)
|
class_names.append(cls_name)
|
||||||
class_ids.append(class_name_to_id[cls_name])
|
class_ids.append(class_name_to_id[cls_name])
|
||||||
|
|
||||||
|
@@ -1,10 +1,7 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk import DetectionTask
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||||
from dds_cloudapi_sdk import TextPrompt
|
|
||||||
from dds_cloudapi_sdk import DetectionModel
|
|
||||||
from dds_cloudapi_sdk import DetectionTarget
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -27,8 +24,9 @@ TEXT_PROMPT = "car . building ."
|
|||||||
IMG_PATH = "notebooks/images/cars.jpg"
|
IMG_PATH = "notebooks/images/cars.jpg"
|
||||||
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
|
||||||
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
|
||||||
GROUNDING_MODEL = DetectionModel.GDino1_5_Pro # DetectionModel.GDino1_6_Pro
|
GROUNDING_MODEL = "GroundingDino-1.5-Pro" # GroundingDino-1.6-Pro
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
|
IOU_THRESHOLD = 0.8
|
||||||
WITH_SLICE_INFERENCE = False
|
WITH_SLICE_INFERENCE = False
|
||||||
SLICE_WH = (480, 480)
|
SLICE_WH = (480, 480)
|
||||||
OVERLAP_RATIO = (0.2, 0.2)
|
OVERLAP_RATIO = (0.2, 0.2)
|
||||||
@@ -49,8 +47,7 @@ config = Config(token)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task using V2Task API
|
||||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
|
|
||||||
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
|
||||||
@@ -65,26 +62,33 @@ if WITH_SLICE_INFERENCE:
|
|||||||
temp_filename = tmpfile.name
|
temp_filename = tmpfile.name
|
||||||
cv2.imwrite(temp_filename, image_slice)
|
cv2.imwrite(temp_filename, image_slice)
|
||||||
image_url = client.upload_file(temp_filename)
|
image_url = client.upload_file(temp_filename)
|
||||||
task = DetectionTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/grounding_dino/detection",
|
||||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
api_body={
|
||||||
targets=[DetectionTarget.BBox], # detect bbox
|
"model": GROUNDING_MODEL,
|
||||||
model=GROUNDING_MODEL, # detect with GroundingDino-1.5-Pro model
|
"image": image_url,
|
||||||
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
"prompt": {
|
||||||
|
"type": "text",
|
||||||
|
"text": TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
# detele the tempfile
|
# delete the tempfile
|
||||||
os.remove(temp_filename)
|
os.remove(temp_filename)
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_ids = []
|
class_ids = []
|
||||||
objects = result.objects
|
objects = result["objects"]
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
cls_name = obj.category.lower().strip()
|
cls_name = obj["category"].lower().strip()
|
||||||
class_ids.append(class_name_to_id[cls_name])
|
class_ids.append(class_name_to_id[cls_name])
|
||||||
# ensure input_boxes with shape (_, 4)
|
# ensure input_boxes with shape (_, 4)
|
||||||
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
input_boxes = np.array(input_boxes).reshape(-1, 4)
|
||||||
@@ -96,7 +100,7 @@ if WITH_SLICE_INFERENCE:
|
|||||||
callback=callback,
|
callback=callback,
|
||||||
slice_wh=SLICE_WH,
|
slice_wh=SLICE_WH,
|
||||||
overlap_ratio_wh=OVERLAP_RATIO,
|
overlap_ratio_wh=OVERLAP_RATIO,
|
||||||
iou_threshold=0.5,
|
iou_threshold=IOU_THRESHOLD,
|
||||||
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION
|
||||||
)
|
)
|
||||||
detections = slicer(cv2.imread(IMG_PATH))
|
detections = slicer(cv2.imread(IMG_PATH))
|
||||||
@@ -107,18 +111,25 @@ if WITH_SLICE_INFERENCE:
|
|||||||
else:
|
else:
|
||||||
image_url = client.upload_file(IMG_PATH)
|
image_url = client.upload_file(IMG_PATH)
|
||||||
|
|
||||||
task = DetectionTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/grounding_dino/detection",
|
||||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
api_body={
|
||||||
targets=[DetectionTarget.BBox], # detect bbox
|
"model": GROUNDING_MODEL,
|
||||||
model=GROUNDING_MODEL, # detect with GroundingDINO-1.5-Pro model
|
"image": image_url,
|
||||||
bbox_threshold=BOX_THRESHOLD, # box confidence threshold
|
"prompt": {
|
||||||
|
"type": "text",
|
||||||
|
"text": TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result.objects # the list of detected objects
|
objects = result["objects"] # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -127,9 +138,9 @@ else:
|
|||||||
class_ids = []
|
class_ids = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
cls_name = obj.category.lower().strip()
|
cls_name = obj["category"].lower().strip()
|
||||||
class_names.append(cls_name)
|
class_names.append(cls_name)
|
||||||
class_ids.append(class_name_to_id[cls_name])
|
class_ids.append(class_name_to_id[cls_name])
|
||||||
|
|
||||||
|
@@ -23,7 +23,7 @@ parser.add_argument("--text-prompt", default="car. tire.")
|
|||||||
parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
|
parser.add_argument("--img-path", default="notebooks/images/truck.jpg")
|
||||||
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
|
parser.add_argument("--sam2-checkpoint", default="./checkpoints/sam2.1_hiera_large.pt")
|
||||||
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
|
parser.add_argument("--sam2-model-config", default="configs/sam2.1/sam2.1_hiera_l.yaml")
|
||||||
parser.add_argument("--output-dir", default="outputs/test_sam2.1")
|
parser.add_argument("--output-dir", default="outputs/grounded_sam2_hf_demo")
|
||||||
parser.add_argument("--no-dump-json", action="store_true")
|
parser.add_argument("--no-dump-json", action="store_true")
|
||||||
parser.add_argument("--force-cpu", action="store_true")
|
parser.add_argument("--force-cpu", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -44,7 +44,7 @@ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|||||||
# use bfloat16
|
# use bfloat16
|
||||||
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
if torch.cuda.get_device_properties(0).major >= 8:
|
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
|
||||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
@@ -61,6 +61,7 @@ boxes, confidences, labels = predict(
|
|||||||
caption=text,
|
caption=text,
|
||||||
box_threshold=BOX_THRESHOLD,
|
box_threshold=BOX_THRESHOLD,
|
||||||
text_threshold=TEXT_THRESHOLD,
|
text_threshold=TEXT_THRESHOLD,
|
||||||
|
device=DEVICE
|
||||||
)
|
)
|
||||||
|
|
||||||
# process the box prompt for SAM 2
|
# process the box prompt for SAM 2
|
||||||
@@ -70,9 +71,9 @@ input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
|
|||||||
|
|
||||||
|
|
||||||
# FIXME: figure how does this influence the G-DINO model
|
# FIXME: figure how does this influence the G-DINO model
|
||||||
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
if torch.cuda.get_device_properties(0).major >= 8:
|
if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
|
||||||
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
536
grounded_sam2_tracking_camera_with_continuous_id.py
Normal file
536
grounded_sam2_tracking_camera_with_continuous_id.py
Normal file
@@ -0,0 +1,536 @@
|
|||||||
|
import copy
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import supervision as sv
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from sam2.build_sam import build_sam2, build_sam2_video_predictor
|
||||||
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
|
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
||||||
|
from utils.common_utils import CommonUtils
|
||||||
|
from utils.mask_dictionary_model import MaskDictionaryModel, ObjectInfo
|
||||||
|
from utils.track_utils import sample_points_from_masks
|
||||||
|
from utils.video_utils import create_video_from_images
|
||||||
|
|
||||||
|
# Setup environment
|
||||||
|
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
|
||||||
|
if torch.cuda.get_device_properties(0).major >= 8:
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
|
class GroundingDinoPredictor:
|
||||||
|
"""
|
||||||
|
Wrapper for using a GroundingDINO model for zero-shot object detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_id="IDEA-Research/grounding-dino-tiny", device="cuda"):
|
||||||
|
"""
|
||||||
|
Initialize the GroundingDINO predictor.
|
||||||
|
Args:
|
||||||
|
model_id (str): HuggingFace model ID to load.
|
||||||
|
device (str): Device to run the model on ('cuda' or 'cpu').
|
||||||
|
"""
|
||||||
|
from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||||
|
self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(
|
||||||
|
device
|
||||||
|
)
|
||||||
|
|
||||||
|
def predict(
|
||||||
|
self,
|
||||||
|
image: "PIL.Image.Image",
|
||||||
|
text_prompts: str,
|
||||||
|
box_threshold=0.25,
|
||||||
|
text_threshold=0.25,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Perform object detection using text prompts.
|
||||||
|
Args:
|
||||||
|
image (PIL.Image.Image): Input RGB image.
|
||||||
|
text_prompts (str): Text prompt describing target objects.
|
||||||
|
box_threshold (float): Confidence threshold for box selection.
|
||||||
|
text_threshold (float): Confidence threshold for text match.
|
||||||
|
Returns:
|
||||||
|
Tuple[Tensor, List[str]]: Bounding boxes and matched class labels.
|
||||||
|
"""
|
||||||
|
inputs = self.processor(
|
||||||
|
images=image, text=text_prompts, return_tensors="pt"
|
||||||
|
).to(self.device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
|
||||||
|
results = self.processor.post_process_grounded_object_detection(
|
||||||
|
outputs,
|
||||||
|
inputs.input_ids,
|
||||||
|
box_threshold=box_threshold,
|
||||||
|
text_threshold=text_threshold,
|
||||||
|
target_sizes=[image.size[::-1]],
|
||||||
|
)
|
||||||
|
|
||||||
|
return results[0]["boxes"], results[0]["labels"]
|
||||||
|
|
||||||
|
|
||||||
|
class SAM2ImageSegmentor:
|
||||||
|
"""
|
||||||
|
Wrapper class for SAM2-based segmentation given bounding boxes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sam_model_cfg: str, sam_model_ckpt: str, device="cuda"):
|
||||||
|
"""
|
||||||
|
Initialize the SAM2 image segmentor.
|
||||||
|
Args:
|
||||||
|
sam_model_cfg (str): Path to the SAM2 config file.
|
||||||
|
sam_model_ckpt (str): Path to the SAM2 checkpoint file.
|
||||||
|
device (str): Device to load the model on ('cuda' or 'cpu').
|
||||||
|
"""
|
||||||
|
from sam2.build_sam import build_sam2
|
||||||
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
sam_model = build_sam2(sam_model_cfg, sam_model_ckpt, device=device)
|
||||||
|
self.predictor = SAM2ImagePredictor(sam_model)
|
||||||
|
|
||||||
|
def set_image(self, image: np.ndarray):
|
||||||
|
"""
|
||||||
|
Set the input image for segmentation.
|
||||||
|
Args:
|
||||||
|
image (np.ndarray): RGB image array with shape (H, W, 3).
|
||||||
|
"""
|
||||||
|
self.predictor.set_image(image)
|
||||||
|
|
||||||
|
def predict_masks_from_boxes(self, boxes: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Predict segmentation masks from given bounding boxes.
|
||||||
|
Args:
|
||||||
|
boxes (torch.Tensor): Bounding boxes as (N, 4) tensor.
|
||||||
|
Returns:
|
||||||
|
Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
- masks: Binary masks per box, shape (N, H, W)
|
||||||
|
- scores: Confidence scores for each mask
|
||||||
|
- logits: Raw logits from the model
|
||||||
|
"""
|
||||||
|
masks, scores, logits = self.predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=boxes,
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Normalize shape to (N, H, W)
|
||||||
|
if masks.ndim == 2:
|
||||||
|
masks = masks[None]
|
||||||
|
scores = scores[None]
|
||||||
|
logits = logits[None]
|
||||||
|
elif masks.ndim == 4:
|
||||||
|
masks = masks.squeeze(1)
|
||||||
|
|
||||||
|
return masks, scores, logits
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalObjectTracker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grounding_model_id="IDEA-Research/grounding-dino-tiny",
|
||||||
|
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
|
||||||
|
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
|
||||||
|
device="cuda",
|
||||||
|
prompt_text="car.",
|
||||||
|
detection_interval=20,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize an incremental object tracker using GroundingDINO and SAM2.
|
||||||
|
Args:
|
||||||
|
grounding_model_id (str): HuggingFace model ID for GroundingDINO.
|
||||||
|
sam2_model_cfg (str): Path to SAM2 model config file.
|
||||||
|
sam2_ckpt_path (str): Path to SAM2 model checkpoint.
|
||||||
|
device (str): Device to run the models on ('cuda' or 'cpu').
|
||||||
|
prompt_text (str): Initial text prompt for detection.
|
||||||
|
detection_interval (int): Frame interval between full detections.
|
||||||
|
"""
|
||||||
|
self.device = device
|
||||||
|
self.detection_interval = detection_interval
|
||||||
|
self.prompt_text = prompt_text
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
self.grounding_predictor = GroundingDinoPredictor(
|
||||||
|
model_id=grounding_model_id, device=device
|
||||||
|
)
|
||||||
|
self.sam2_segmentor = SAM2ImageSegmentor(
|
||||||
|
sam_model_cfg=sam2_model_cfg,
|
||||||
|
sam_model_ckpt=sam2_ckpt_path,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.video_predictor = build_sam2_video_predictor(
|
||||||
|
sam2_model_cfg, sam2_ckpt_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize inference state
|
||||||
|
self.inference_state = self.video_predictor.init_state()
|
||||||
|
self.inference_state["images"] = torch.empty((0, 3, 1024, 1024), device=device)
|
||||||
|
self.total_frames = 0
|
||||||
|
self.objects_count = 0
|
||||||
|
self.frame_cache_limit = detection_interval - 1 # or higher depending on memory
|
||||||
|
|
||||||
|
# Store tracking results
|
||||||
|
self.last_mask_dict = MaskDictionaryModel()
|
||||||
|
self.track_dict = MaskDictionaryModel()
|
||||||
|
|
||||||
|
def add_image(self, image_np: np.ndarray):
|
||||||
|
"""
|
||||||
|
Add a new image frame to the tracker and perform detection or tracking update.
|
||||||
|
Args:
|
||||||
|
image_np (np.ndarray): Input RGB image as (H, W, 3), dtype=uint8.
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Annotated image with object masks and labels.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
img_pil = Image.fromarray(image_np)
|
||||||
|
|
||||||
|
# Step 1: Perform detection every detection_interval frames
|
||||||
|
if self.total_frames % self.detection_interval == 0:
|
||||||
|
if (
|
||||||
|
self.inference_state["video_height"] is None
|
||||||
|
or self.inference_state["video_width"] is None
|
||||||
|
):
|
||||||
|
(
|
||||||
|
self.inference_state["video_height"],
|
||||||
|
self.inference_state["video_width"],
|
||||||
|
) = image_np.shape[:2]
|
||||||
|
|
||||||
|
if self.inference_state["images"].shape[0] > self.frame_cache_limit:
|
||||||
|
print(
|
||||||
|
f"[Reset] Resetting inference state after {self.frame_cache_limit} frames to free memory."
|
||||||
|
)
|
||||||
|
self.inference_state = self.video_predictor.init_state()
|
||||||
|
self.inference_state["images"] = torch.empty(
|
||||||
|
(0, 3, 1024, 1024), device=self.device
|
||||||
|
)
|
||||||
|
(
|
||||||
|
self.inference_state["video_height"],
|
||||||
|
self.inference_state["video_width"],
|
||||||
|
) = image_np.shape[:2]
|
||||||
|
|
||||||
|
# 1.1 GroundingDINO object detection
|
||||||
|
boxes, labels = self.grounding_predictor.predict(img_pil, self.prompt_text)
|
||||||
|
if boxes.shape[0] == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1.2 SAM2 segmentation from detection boxes
|
||||||
|
self.sam2_segmentor.set_image(image_np)
|
||||||
|
masks, scores, logits = self.sam2_segmentor.predict_masks_from_boxes(boxes)
|
||||||
|
|
||||||
|
# 1.3 Build MaskDictionaryModel
|
||||||
|
mask_dict = MaskDictionaryModel(
|
||||||
|
promote_type="mask", mask_name=f"mask_{self.total_frames:05d}.npy"
|
||||||
|
)
|
||||||
|
mask_dict.add_new_frame_annotation(
|
||||||
|
mask_list=torch.tensor(masks).to(self.device),
|
||||||
|
box_list=torch.tensor(boxes),
|
||||||
|
label_list=labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1.4 Object ID tracking and IOU-based update
|
||||||
|
self.objects_count = mask_dict.update_masks(
|
||||||
|
tracking_annotation_dict=self.last_mask_dict,
|
||||||
|
iou_threshold=0.3,
|
||||||
|
objects_count=self.objects_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1.5 Reset video tracker state
|
||||||
|
frame_idx = self.video_predictor.add_new_frame(
|
||||||
|
self.inference_state, image_np
|
||||||
|
)
|
||||||
|
self.video_predictor.reset_state(self.inference_state)
|
||||||
|
|
||||||
|
for object_id, object_info in mask_dict.labels.items():
|
||||||
|
frame_idx, _, _ = self.video_predictor.add_new_mask(
|
||||||
|
self.inference_state,
|
||||||
|
frame_idx,
|
||||||
|
object_id,
|
||||||
|
object_info.mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.track_dict = copy.deepcopy(mask_dict)
|
||||||
|
self.last_mask_dict = mask_dict
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Step 2: Use incremental tracking for intermediate frames
|
||||||
|
frame_idx = self.video_predictor.add_new_frame(
|
||||||
|
self.inference_state, image_np
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Tracking propagation using the video predictor
|
||||||
|
frame_idx, obj_ids, video_res_masks = self.video_predictor.infer_single_frame(
|
||||||
|
inference_state=self.inference_state,
|
||||||
|
frame_idx=frame_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Update the mask dictionary based on tracked masks
|
||||||
|
frame_masks = MaskDictionaryModel()
|
||||||
|
for i, obj_id in enumerate(obj_ids):
|
||||||
|
out_mask = video_res_masks[i] > 0.0
|
||||||
|
object_info = ObjectInfo(
|
||||||
|
instance_id=obj_id,
|
||||||
|
mask=out_mask[0],
|
||||||
|
class_name=self.track_dict.get_target_class_name(obj_id),
|
||||||
|
logit=self.track_dict.get_target_logit(obj_id),
|
||||||
|
)
|
||||||
|
object_info.update_box()
|
||||||
|
frame_masks.labels[obj_id] = object_info
|
||||||
|
frame_masks.mask_name = f"mask_{frame_idx:05d}.npy"
|
||||||
|
frame_masks.mask_height = out_mask.shape[-2]
|
||||||
|
frame_masks.mask_width = out_mask.shape[-1]
|
||||||
|
|
||||||
|
self.last_mask_dict = copy.deepcopy(frame_masks)
|
||||||
|
|
||||||
|
# Step 5: Build mask array
|
||||||
|
H, W = image_np.shape[:2]
|
||||||
|
mask_img = torch.zeros((H, W), dtype=torch.int32)
|
||||||
|
for obj_id, obj_info in self.last_mask_dict.labels.items():
|
||||||
|
mask_img[obj_info.mask == True] = obj_id
|
||||||
|
|
||||||
|
mask_array = mask_img.cpu().numpy()
|
||||||
|
|
||||||
|
# Step 6: Visualization
|
||||||
|
annotated_frame = self.visualize_frame_with_mask_and_metadata(
|
||||||
|
image_np=image_np,
|
||||||
|
mask_array=mask_array,
|
||||||
|
json_metadata=self.last_mask_dict.to_dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"[Tracker] Total processed frames: {self.total_frames}")
|
||||||
|
self.total_frames += 1
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
def set_prompt(self, new_prompt: str):
|
||||||
|
"""
|
||||||
|
Dynamically update the GroundingDINO prompt and reset tracking state
|
||||||
|
to force a new object detection.
|
||||||
|
"""
|
||||||
|
self.prompt_text = new_prompt
|
||||||
|
self.total_frames = 0 # Trigger immediate re-detection
|
||||||
|
self.inference_state = self.video_predictor.init_state()
|
||||||
|
self.inference_state["images"] = torch.empty(
|
||||||
|
(0, 3, 1024, 1024), device=self.device
|
||||||
|
)
|
||||||
|
self.inference_state["video_height"] = None
|
||||||
|
self.inference_state["video_width"] = None
|
||||||
|
|
||||||
|
print(f"[Prompt Updated] New prompt: '{new_prompt}'. Tracker state reset.")
|
||||||
|
|
||||||
|
def save_current_state(self, output_dir, raw_image: np.ndarray = None):
|
||||||
|
"""
|
||||||
|
Save the current mask, metadata, raw image, and annotated result.
|
||||||
|
Args:
|
||||||
|
output_dir (str): The root output directory.
|
||||||
|
raw_image (np.ndarray, optional): The original input image (RGB).
|
||||||
|
"""
|
||||||
|
mask_data_dir = os.path.join(output_dir, "mask_data")
|
||||||
|
json_data_dir = os.path.join(output_dir, "json_data")
|
||||||
|
image_data_dir = os.path.join(output_dir, "images")
|
||||||
|
vis_data_dir = os.path.join(output_dir, "result")
|
||||||
|
|
||||||
|
os.makedirs(mask_data_dir, exist_ok=True)
|
||||||
|
os.makedirs(json_data_dir, exist_ok=True)
|
||||||
|
os.makedirs(image_data_dir, exist_ok=True)
|
||||||
|
os.makedirs(vis_data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
frame_masks = self.last_mask_dict
|
||||||
|
|
||||||
|
# Ensure mask_name is valid
|
||||||
|
if not frame_masks.mask_name or not frame_masks.mask_name.endswith(".npy"):
|
||||||
|
frame_masks.mask_name = f"mask_{self.total_frames:05d}.npy"
|
||||||
|
|
||||||
|
base_name = f"image_{self.total_frames:05d}"
|
||||||
|
|
||||||
|
# Save segmentation mask
|
||||||
|
mask_img = torch.zeros(frame_masks.mask_height, frame_masks.mask_width)
|
||||||
|
for obj_id, obj_info in frame_masks.labels.items():
|
||||||
|
mask_img[obj_info.mask == True] = obj_id
|
||||||
|
np.save(
|
||||||
|
os.path.join(mask_data_dir, frame_masks.mask_name),
|
||||||
|
mask_img.numpy().astype(np.uint16),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save metadata as JSON
|
||||||
|
json_path = os.path.join(json_data_dir, base_name + ".json")
|
||||||
|
frame_masks.to_json(json_path)
|
||||||
|
|
||||||
|
# Save raw input image
|
||||||
|
if raw_image is not None:
|
||||||
|
image_bgr = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
|
||||||
|
cv2.imwrite(os.path.join(image_data_dir, base_name + ".jpg"), image_bgr)
|
||||||
|
|
||||||
|
# Save annotated image with mask, bounding boxes, and labels
|
||||||
|
annotated_image = self.visualize_frame_with_mask_and_metadata(
|
||||||
|
image_np=raw_image,
|
||||||
|
mask_array=mask_img.numpy().astype(np.uint16),
|
||||||
|
json_metadata=frame_masks.to_dict(),
|
||||||
|
)
|
||||||
|
annotated_bgr = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
|
||||||
|
cv2.imwrite(
|
||||||
|
os.path.join(vis_data_dir, base_name + "_annotated.jpg"), annotated_bgr
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[Saved] {base_name}.jpg and {base_name}_annotated.jpg saved successfully."
|
||||||
|
)
|
||||||
|
|
||||||
|
def visualize_frame_with_mask_and_metadata(
|
||||||
|
self,
|
||||||
|
image_np: np.ndarray,
|
||||||
|
mask_array: np.ndarray,
|
||||||
|
json_metadata: dict,
|
||||||
|
):
|
||||||
|
image = image_np.copy()
|
||||||
|
H, W = image.shape[:2]
|
||||||
|
|
||||||
|
# Step 1: Parse metadata and build object entries
|
||||||
|
metadata_lookup = json_metadata.get("labels", {})
|
||||||
|
|
||||||
|
all_object_ids = []
|
||||||
|
all_object_boxes = []
|
||||||
|
all_object_classes = []
|
||||||
|
all_object_masks = []
|
||||||
|
|
||||||
|
for obj_id_str, obj_info in metadata_lookup.items():
|
||||||
|
instance_id = obj_info.get("instance_id")
|
||||||
|
if instance_id is None or instance_id == 0:
|
||||||
|
continue
|
||||||
|
if instance_id not in np.unique(mask_array):
|
||||||
|
continue
|
||||||
|
|
||||||
|
object_mask = mask_array == instance_id
|
||||||
|
all_object_ids.append(instance_id)
|
||||||
|
x1 = obj_info.get("x1", 0)
|
||||||
|
y1 = obj_info.get("y1", 0)
|
||||||
|
x2 = obj_info.get("x2", 0)
|
||||||
|
y2 = obj_info.get("y2", 0)
|
||||||
|
all_object_boxes.append([x1, y1, x2, y2])
|
||||||
|
all_object_classes.append(obj_info.get("class_name", "unknown"))
|
||||||
|
all_object_masks.append(object_mask[None]) # Shape (1, H, W)
|
||||||
|
|
||||||
|
# Step 2: Check if valid objects exist
|
||||||
|
if len(all_object_ids) == 0:
|
||||||
|
print("No valid object instances found in metadata.")
|
||||||
|
return image
|
||||||
|
|
||||||
|
# Step 3: Sort by instance ID
|
||||||
|
paired = list(
|
||||||
|
zip(all_object_ids, all_object_boxes, all_object_masks, all_object_classes)
|
||||||
|
)
|
||||||
|
paired.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
all_object_ids = [p[0] for p in paired]
|
||||||
|
all_object_boxes = [p[1] for p in paired]
|
||||||
|
all_object_masks = [p[2] for p in paired]
|
||||||
|
all_object_classes = [p[3] for p in paired]
|
||||||
|
|
||||||
|
# Step 4: Build detections
|
||||||
|
all_object_masks = np.concatenate(all_object_masks, axis=0)
|
||||||
|
detections = sv.Detections(
|
||||||
|
xyxy=np.array(all_object_boxes),
|
||||||
|
mask=all_object_masks,
|
||||||
|
class_id=np.array(all_object_ids, dtype=np.int32),
|
||||||
|
)
|
||||||
|
labels = [
|
||||||
|
f"{instance_id}: {class_name}"
|
||||||
|
for instance_id, class_name in zip(all_object_ids, all_object_classes)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Step 5: Annotate image
|
||||||
|
annotated_frame = image.copy()
|
||||||
|
mask_annotator = sv.MaskAnnotator()
|
||||||
|
box_annotator = sv.BoxAnnotator()
|
||||||
|
label_annotator = sv.LabelAnnotator()
|
||||||
|
|
||||||
|
annotated_frame = mask_annotator.annotate(annotated_frame, detections)
|
||||||
|
annotated_frame = box_annotator.annotate(annotated_frame, detections)
|
||||||
|
annotated_frame = label_annotator.annotate(annotated_frame, detections, labels)
|
||||||
|
|
||||||
|
return annotated_frame
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from utils.common_utils import CommonUtils
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parameter settings
|
||||||
|
output_dir = "./outputs"
|
||||||
|
prompt_text = "hand."
|
||||||
|
detection_interval = 20
|
||||||
|
max_frames = 300 # Maximum number of frames to process (prevents infinite loop)
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize the object tracker
|
||||||
|
tracker = IncrementalObjectTracker(
|
||||||
|
grounding_model_id="IDEA-Research/grounding-dino-tiny",
|
||||||
|
sam2_model_cfg="configs/sam2.1/sam2.1_hiera_l.yaml",
|
||||||
|
sam2_ckpt_path="./checkpoints/sam2.1_hiera_large.pt",
|
||||||
|
device="cuda",
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
detection_interval=detection_interval,
|
||||||
|
)
|
||||||
|
tracker.set_prompt("person.")
|
||||||
|
|
||||||
|
# Open the camera (or replace with local video file, e.g., cv2.VideoCapture("video.mp4"))
|
||||||
|
cap = cv2.VideoCapture(0)
|
||||||
|
if not cap.isOpened():
|
||||||
|
print("[Error] Cannot open camera.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("[Info] Camera opened. Press 'q' to quit.")
|
||||||
|
frame_idx = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if not ret:
|
||||||
|
print("[Warning] Failed to capture frame.")
|
||||||
|
break
|
||||||
|
|
||||||
|
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||||
|
print(f"[Frame {frame_idx}] Processing live frame...")
|
||||||
|
process_image = tracker.add_image(frame_rgb)
|
||||||
|
|
||||||
|
if process_image is None or not isinstance(process_image, np.ndarray):
|
||||||
|
print(f"[Warning] Skipped frame {frame_idx} due to empty result.")
|
||||||
|
frame_idx += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# process_image_bgr = cv2.cvtColor(process_image, cv2.COLOR_RGB2BGR)
|
||||||
|
# cv2.imshow("Live Inference", process_image_bgr)
|
||||||
|
|
||||||
|
|
||||||
|
# if cv2.waitKey(1) & 0xFF == ord('q'):
|
||||||
|
# print("[Info] Quit signal received.")
|
||||||
|
# break
|
||||||
|
|
||||||
|
tracker.save_current_state(output_dir=output_dir, raw_image=frame_rgb)
|
||||||
|
frame_idx += 1
|
||||||
|
|
||||||
|
if frame_idx >= max_frames:
|
||||||
|
print(f"[Info] Reached max_frames {max_frames}. Stopping.")
|
||||||
|
break
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("[Info] Interrupted by user (Ctrl+C).")
|
||||||
|
finally:
|
||||||
|
cap.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
print("[Done] Live inference complete.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@@ -1,9 +1,7 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for DINO-X - update to V2Task API
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk.tasks.dinox import DinoxTask
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||||
from dds_cloudapi_sdk.tasks.types import DetectionTarget
|
|
||||||
from dds_cloudapi_sdk import TextPrompt
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -30,6 +28,7 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
|||||||
API_TOKEN_FOR_DINOX = "Your API token"
|
API_TOKEN_FOR_DINOX = "Your API token"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
|
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 1: Environment settings and model initialization for SAM 2
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
@@ -98,22 +97,29 @@ config = Config(API_TOKEN_FOR_DINOX)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task using V2Task class
|
||||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = DinoxTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/dinox/detection",
|
||||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
api_body={
|
||||||
bbox_threshold=0.25,
|
"model": "DINO-X-1.0",
|
||||||
targets=[DetectionTarget.BBox],
|
"image": image_url,
|
||||||
|
"prompt": {
|
||||||
|
"type": "text",
|
||||||
|
"text": TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result.objects # the list of detected objects
|
objects = result["objects"] # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -121,9 +127,9 @@ confidences = []
|
|||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
class_names.append(obj.category)
|
class_names.append(obj["category"])
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
|
|
||||||
|
@@ -1,10 +1,7 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for Grounding DINO 1.5 - Update to V2Task API
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk import DetectionTask
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||||
from dds_cloudapi_sdk import TextPrompt
|
|
||||||
from dds_cloudapi_sdk import DetectionModel
|
|
||||||
from dds_cloudapi_sdk import DetectionTarget
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -31,6 +28,7 @@ SAVE_TRACKING_RESULTS_DIR = "./tracking_results"
|
|||||||
API_TOKEN_FOR_GD1_5 = "Your API token"
|
API_TOKEN_FOR_GD1_5 = "Your API token"
|
||||||
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
|
||||||
BOX_THRESHOLD = 0.2
|
BOX_THRESHOLD = 0.2
|
||||||
|
IOU_THRESHOLD = 0.8 # 添加IOU阈值参数
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 1: Environment settings and model initialization for SAM 2
|
Step 1: Environment settings and model initialization for SAM 2
|
||||||
@@ -99,33 +97,38 @@ config = Config(API_TOKEN_FOR_GD1_5)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task using V2Task class
|
||||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = DetectionTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/grounding_dino/detection",
|
||||||
prompts=[TextPrompt(text=TEXT_PROMPT)],
|
api_body={
|
||||||
targets=[DetectionTarget.BBox], # detect bbox
|
"model": "GroundingDino-1.5-Pro",
|
||||||
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
"image": image_url,
|
||||||
bbox_threshold=BOX_THRESHOLD,
|
"prompt": {
|
||||||
|
"type": "text",
|
||||||
|
"text": TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result.objects # the list of detected objects
|
objects = result["objects"] # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
class_names.append(obj.category)
|
class_names.append(obj["category"])
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
|
|
||||||
|
@@ -1,11 +1,7 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk import DetectionTask
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||||
from dds_cloudapi_sdk import TextPrompt
|
|
||||||
from dds_cloudapi_sdk import DetectionModel
|
|
||||||
from dds_cloudapi_sdk import DetectionTarget
|
|
||||||
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
@@ -51,6 +47,9 @@ grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).
|
|||||||
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
# setup the input image and text prompt for SAM 2 and Grounding DINO
|
||||||
# VERY important: text queries need to be lowercased + end with a dot
|
# VERY important: text queries need to be lowercased + end with a dot
|
||||||
text = "car."
|
text = "car."
|
||||||
|
BOX_THRESHOLD = 0.2
|
||||||
|
IOU_THRESHOLD = 0.8
|
||||||
|
GROUNDING_MODEL = "GroundingDino-1.6-Pro" # 使用字符串替代枚举值
|
||||||
|
|
||||||
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
||||||
video_dir = "notebooks/videos/car"
|
video_dir = "notebooks/videos/car"
|
||||||
@@ -102,24 +101,32 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
task = DetectionTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/grounding_dino/detection",
|
||||||
prompts=[TextPrompt(text=text)],
|
api_body={
|
||||||
targets=[DetectionTarget.BBox], # detect bbox
|
"model": GROUNDING_MODEL,
|
||||||
model=DetectionModel.GDino1_6_Pro, # detect with GroundingDino-1.5-Pro model
|
"image": image_url,
|
||||||
|
"prompt": {
|
||||||
|
"type": "text",
|
||||||
|
"text": text
|
||||||
|
},
|
||||||
|
"targets": ["bbox"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result.objects # the list of detected objects
|
objects = result["objects"] # the list of detected objects
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
confidences = []
|
confidences = []
|
||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
class_names.append(obj.category)
|
class_names.append(obj["category"])
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
OBJECTS = class_names
|
OBJECTS = class_names
|
||||||
@@ -154,7 +161,7 @@ for start_frame_idx in range(0, len(frame_names), step):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=0.8, objects_count=objects_count)
|
objects_count = mask_dict.update_masks(tracking_annotation_dict=sam2_masks, iou_threshold=IOU_THRESHOLD, objects_count=objects_count)
|
||||||
print("objects_count", objects_count)
|
print("objects_count", objects_count)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
@@ -1,10 +1,7 @@
|
|||||||
# dds cloudapi for Grounding DINO 1.5
|
# dds cloudapi for Grounding DINO 1.5 - update to V2Task API
|
||||||
from dds_cloudapi_sdk import Config
|
from dds_cloudapi_sdk import Config
|
||||||
from dds_cloudapi_sdk import Client
|
from dds_cloudapi_sdk import Client
|
||||||
from dds_cloudapi_sdk import DetectionTask
|
from dds_cloudapi_sdk.tasks.v2_task import V2Task
|
||||||
from dds_cloudapi_sdk import TextPrompt
|
|
||||||
from dds_cloudapi_sdk import DetectionModel
|
|
||||||
from dds_cloudapi_sdk import DetectionTarget
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import cv2
|
import cv2
|
||||||
@@ -54,6 +51,11 @@ inference_state = video_predictor.init_state(video_path=video_dir)
|
|||||||
ann_frame_idx = 0 # the frame index we interact with
|
ann_frame_idx = 0 # the frame index we interact with
|
||||||
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
||||||
|
|
||||||
|
# 添加参数设置
|
||||||
|
TEXT_PROMPT = "children. pillow"
|
||||||
|
BOX_THRESHOLD = 0.2
|
||||||
|
IOU_THRESHOLD = 0.8
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
Step 2: Prompt Grounding DINO 1.5 with Cloud API for box coordinates
|
||||||
@@ -70,23 +72,29 @@ config = Config(token)
|
|||||||
# Step 2: initialize the client
|
# Step 2: initialize the client
|
||||||
client = Client(config)
|
client = Client(config)
|
||||||
|
|
||||||
# Step 3: run the task by DetectionTask class
|
# Step 3: run the task using V2Task class
|
||||||
# image_url = "https://algosplt.oss-cn-shenzhen.aliyuncs.com/test_files/tasks/detection/iron_man.jpg"
|
|
||||||
# if you are processing local image file, upload them to DDS server to get the image url
|
# if you are processing local image file, upload them to DDS server to get the image url
|
||||||
image_url = client.upload_file(img_path)
|
image_url = client.upload_file(img_path)
|
||||||
|
|
||||||
task = DetectionTask(
|
task = V2Task(
|
||||||
image_url=image_url,
|
api_path="/v2/task/grounding_dino/detection",
|
||||||
prompts=[TextPrompt(text="children. pillow")],
|
api_body={
|
||||||
targets=[DetectionTarget.BBox], # detect bbox
|
"model": "GroundingDino-1.5-Pro",
|
||||||
model=DetectionModel.GDino1_5_Pro, # detect with GroundingDino-1.5-Pro model
|
"image": image_url,
|
||||||
bbox_threshold=0.2,
|
"prompt": {
|
||||||
|
"type": "text",
|
||||||
|
"text": TEXT_PROMPT
|
||||||
|
},
|
||||||
|
"targets": ["bbox"],
|
||||||
|
"bbox_threshold": BOX_THRESHOLD,
|
||||||
|
"iou_threshold": IOU_THRESHOLD,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
client.run_task(task)
|
client.run_task(task)
|
||||||
result = task.result
|
result = task.result
|
||||||
|
|
||||||
objects = result.objects # the list of detected objects
|
objects = result["objects"] # the list of detected objects
|
||||||
|
|
||||||
|
|
||||||
input_boxes = []
|
input_boxes = []
|
||||||
@@ -94,9 +102,9 @@ confidences = []
|
|||||||
class_names = []
|
class_names = []
|
||||||
|
|
||||||
for idx, obj in enumerate(objects):
|
for idx, obj in enumerate(objects):
|
||||||
input_boxes.append(obj.bbox)
|
input_boxes.append(obj["bbox"])
|
||||||
confidences.append(obj.score)
|
confidences.append(obj["score"])
|
||||||
class_names.append(obj.category)
|
class_names.append(obj["category"])
|
||||||
|
|
||||||
input_boxes = np.array(input_boxes)
|
input_boxes = np.array(input_boxes)
|
||||||
|
|
||||||
|
@@ -16,7 +16,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.utils.checkpoint as checkpoint
|
import torch.utils.checkpoint as checkpoint
|
||||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
||||||
|
|
||||||
from grounding_dino.groundingdino.util.misc import NestedTensor
|
from grounding_dino.groundingdino.util.misc import NestedTensor
|
||||||
|
|
||||||
@@ -113,7 +113,7 @@ class WindowAttention(nn.Module):
|
|||||||
# get pair-wise relative position index for each token inside the window
|
# get pair-wise relative position index for each token inside the window
|
||||||
coords_h = torch.arange(self.window_size[0])
|
coords_h = torch.arange(self.window_size[0])
|
||||||
coords_w = torch.arange(self.window_size[1])
|
coords_w = torch.arange(self.window_size[1])
|
||||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
||||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||||
|
@@ -15,6 +15,19 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <torch/version.h>
|
||||||
|
|
||||||
|
// Check PyTorch version and define appropriate macros
|
||||||
|
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 6
|
||||||
|
// PyTorch 2.x and above
|
||||||
|
#define GET_TENSOR_TYPE(x) x.scalar_type()
|
||||||
|
#define IS_CUDA_TENSOR(x) x.device().is_cuda()
|
||||||
|
#else
|
||||||
|
// PyTorch 1.x
|
||||||
|
#define GET_TENSOR_TYPE(x) x.type()
|
||||||
|
#define IS_CUDA_TENSOR(x) x.type().is_cuda()
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace groundingdino {
|
namespace groundingdino {
|
||||||
|
|
||||||
@@ -32,11 +45,11 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
|
|
||||||
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
|
||||||
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
|
||||||
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
|
||||||
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
|
||||||
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
|
||||||
|
|
||||||
const int batch = value.size(0);
|
const int batch = value.size(0);
|
||||||
const int spatial_size = value.size(1);
|
const int spatial_size = value.size(1);
|
||||||
@@ -62,7 +75,7 @@ at::Tensor ms_deform_attn_cuda_forward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto columns = output_n.select(0, n);
|
auto columns = output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_forward_cuda", ([&] {
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
spatial_shapes.data<int64_t>(),
|
spatial_shapes.data<int64_t>(),
|
||||||
@@ -98,12 +111,12 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
||||||
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
||||||
|
|
||||||
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(value), "value must be a CUDA tensor");
|
||||||
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(spatial_shapes), "spatial_shapes must be a CUDA tensor");
|
||||||
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(level_start_index), "level_start_index must be a CUDA tensor");
|
||||||
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(sampling_loc), "sampling_loc must be a CUDA tensor");
|
||||||
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(attn_weight), "attn_weight must be a CUDA tensor");
|
||||||
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
AT_ASSERTM(IS_CUDA_TENSOR(grad_output), "grad_output must be a CUDA tensor");
|
||||||
|
|
||||||
const int batch = value.size(0);
|
const int batch = value.size(0);
|
||||||
const int spatial_size = value.size(1);
|
const int spatial_size = value.size(1);
|
||||||
@@ -132,7 +145,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
for (int n = 0; n < batch/im2col_step_; ++n)
|
||||||
{
|
{
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
auto grad_output_g = grad_output_n.select(0, n);
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
AT_DISPATCH_FLOATING_TYPES(GET_TENSOR_TYPE(value), "ms_deform_attn_backward_cuda", ([&] {
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
||||||
grad_output_g.data<scalar_t>(),
|
grad_output_g.data<scalar_t>(),
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
||||||
|
@@ -8,7 +8,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from timm.models.layers import DropPath
|
from timm.layers import DropPath
|
||||||
|
|
||||||
|
|
||||||
class FeatureResizer(nn.Module):
|
class FeatureResizer(nn.Module):
|
||||||
|
@@ -470,6 +470,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
ref_y, ref_x = torch.meshgrid(
|
ref_y, ref_x = torch.meshgrid(
|
||||||
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
||||||
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
|
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
|
||||||
|
indexing="ij"
|
||||||
)
|
)
|
||||||
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
|
||||||
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
|
||||||
@@ -859,7 +860,7 @@ class DeformableTransformerDecoderLayer(nn.Module):
|
|||||||
return tensor if pos is None else tensor + pos
|
return tensor if pos is None else tensor + pos
|
||||||
|
|
||||||
def forward_ffn(self, tgt):
|
def forward_ffn(self, tgt):
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.amp.autocast("cuda", enabled=False):
|
||||||
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
|
||||||
tgt = tgt + self.dropout4(tgt2)
|
tgt = tgt + self.dropout4(tgt2)
|
||||||
tgt = self.norm3(tgt)
|
tgt = self.norm3(tgt)
|
||||||
|
@@ -79,6 +79,7 @@ def gen_encoder_output_proposals(
|
|||||||
grid_y, grid_x = torch.meshgrid(
|
grid_y, grid_x = torch.meshgrid(
|
||||||
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
|
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
|
||||||
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
|
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
|
||||||
|
indexing="ij"
|
||||||
)
|
)
|
||||||
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
|
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) # H_, W_, 2
|
||||||
|
|
||||||
|
@@ -118,7 +118,7 @@ def masks_to_boxes(masks):
|
|||||||
|
|
||||||
y = torch.arange(0, h, dtype=torch.float)
|
y = torch.arange(0, h, dtype=torch.float)
|
||||||
x = torch.arange(0, w, dtype=torch.float)
|
x = torch.arange(0, w, dtype=torch.float)
|
||||||
y, x = torch.meshgrid(y, x)
|
y, x = torch.meshgrid(y, x, indexing="ij")
|
||||||
|
|
||||||
x_mask = masks * x.unsqueeze(0)
|
x_mask = masks * x.unsqueeze(0)
|
||||||
x_max = x_mask.flatten(1).max(-1)[0]
|
x_max = x_mask.flatten(1).max(-1)[0]
|
||||||
|
@@ -63,6 +63,7 @@ def predict(
|
|||||||
|
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
image = image.to(device)
|
image = image.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = model(image[None], captions=[caption])
|
outputs = model(image[None], captions=[caption])
|
||||||
|
@@ -1,6 +1,68 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = [
|
requires = ["setuptools>=61.0", "wheel"]
|
||||||
"setuptools>=61.0",
|
|
||||||
"torch>=2.3.1",
|
|
||||||
]
|
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "Grounded-SAM-2"
|
||||||
|
version = "1.0"
|
||||||
|
description = "Grounded SAM 2: Ground and Track Anything in Videos"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10.0"
|
||||||
|
license = { text = "Apache 2.0" }
|
||||||
|
authors = [{ name = "Meta AI", email = "segment-anything@meta.com" }]
|
||||||
|
keywords = ["segmentation", "computer vision", "deep learning"]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"torch>=2.3.1",
|
||||||
|
"torchvision>=0.18.1",
|
||||||
|
"numpy>=1.24.4",
|
||||||
|
"tqdm>=4.66.1",
|
||||||
|
"hydra-core>=1.3.2",
|
||||||
|
"iopath>=0.1.10",
|
||||||
|
"pillow>=9.4.0",
|
||||||
|
"opencv-python-headless>=4.11.0.86",
|
||||||
|
"supervision>=0.26.1",
|
||||||
|
"pycocotools>=2.0.10",
|
||||||
|
"transformers>=4.55.1",
|
||||||
|
"addict>=2.4.0",
|
||||||
|
"yapf>=0.43.0",
|
||||||
|
"timm>=1.0.19",
|
||||||
|
"pdf2image>=1.17.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
notebooks = [
|
||||||
|
"matplotlib>=3.9.1",
|
||||||
|
"jupyter>=1.0.0",
|
||||||
|
"opencv-python>=4.7.0",
|
||||||
|
"eva-decord>=0.6.1",
|
||||||
|
]
|
||||||
|
interactive-demo = [
|
||||||
|
"Flask>=3.0.3",
|
||||||
|
"Flask-Cors>=5.0.0",
|
||||||
|
"av>=13.0.0",
|
||||||
|
"dataclasses-json>=0.6.7",
|
||||||
|
"eva-decord>=0.6.1",
|
||||||
|
"gunicorn>=23.0.0",
|
||||||
|
"imagesize>=1.4.1",
|
||||||
|
"pycocotools>=2.0.8",
|
||||||
|
"strawberry-graphql>=0.243.0",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"black==24.2.0",
|
||||||
|
"usort==1.0.2",
|
||||||
|
"ufmt==2.0.0b2",
|
||||||
|
"fvcore>=0.1.5.post20221221",
|
||||||
|
"pandas>=2.2.2",
|
||||||
|
"scikit-image>=0.24.0",
|
||||||
|
"tensorboard>=2.17.0",
|
||||||
|
"pycocotools>=2.0.8",
|
||||||
|
"tensordict>=0.5.0",
|
||||||
|
"opencv-python>=4.7.0",
|
||||||
|
"submitit>=1.5.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
# extensions = [{ name = "sam2._C", sources = ["sam2/csrc/connected_components.cu"] }]
|
||||||
|
packages = ["sam2", "grounding_dino"]
|
||||||
|
@@ -12,7 +12,7 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
||||||
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
|
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames ,process_stream_frame
|
||||||
|
|
||||||
|
|
||||||
class SAM2VideoPredictor(SAM2Base):
|
class SAM2VideoPredictor(SAM2Base):
|
||||||
@@ -43,23 +43,33 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def init_state(
|
def init_state(
|
||||||
self,
|
self,
|
||||||
video_path,
|
video_path=None,
|
||||||
offload_video_to_cpu=False,
|
offload_video_to_cpu=False,
|
||||||
offload_state_to_cpu=False,
|
offload_state_to_cpu=False,
|
||||||
async_loading_frames=False,
|
async_loading_frames=False,
|
||||||
):
|
):
|
||||||
"""Initialize an inference state."""
|
"""Initialize an inference state."""
|
||||||
compute_device = self.device # device of the model
|
compute_device = self.device # device of the model
|
||||||
images, video_height, video_width = load_video_frames(
|
|
||||||
video_path=video_path,
|
|
||||||
image_size=self.image_size,
|
|
||||||
offload_video_to_cpu=offload_video_to_cpu,
|
|
||||||
async_loading_frames=async_loading_frames,
|
|
||||||
compute_device=compute_device,
|
|
||||||
)
|
|
||||||
inference_state = {}
|
inference_state = {}
|
||||||
inference_state["images"] = images
|
if video_path is not None:
|
||||||
inference_state["num_frames"] = len(images)
|
# Preload video frames from file
|
||||||
|
images, video_height, video_width = load_video_frames(
|
||||||
|
video_path=video_path,
|
||||||
|
image_size=self.image_size,
|
||||||
|
offload_video_to_cpu=offload_video_to_cpu,
|
||||||
|
async_loading_frames=async_loading_frames,
|
||||||
|
compute_device=compute_device,
|
||||||
|
)
|
||||||
|
inference_state["images"] = images
|
||||||
|
inference_state["num_frames"] = len(images)
|
||||||
|
else:
|
||||||
|
# Real-time streaming mode
|
||||||
|
print("Real-time streaming mode: waiting for first image input...")
|
||||||
|
images = None
|
||||||
|
video_height, video_width = None, None
|
||||||
|
inference_state["images"] = None
|
||||||
|
inference_state["num_frames"] = 0
|
||||||
|
|
||||||
# whether to offload the video frames to CPU memory
|
# whether to offload the video frames to CPU memory
|
||||||
# turning on this option saves the GPU memory with only a very small overhead
|
# turning on this option saves the GPU memory with only a very small overhead
|
||||||
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
|
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
|
||||||
@@ -107,7 +117,9 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state["tracking_has_started"] = False
|
inference_state["tracking_has_started"] = False
|
||||||
inference_state["frames_already_tracked"] = {}
|
inference_state["frames_already_tracked"] = {}
|
||||||
# Warm up the visual backbone and cache the image feature on frame 0
|
# Warm up the visual backbone and cache the image feature on frame 0
|
||||||
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
if video_path is not None:
|
||||||
|
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
||||||
|
|
||||||
return inference_state
|
return inference_state
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -743,6 +755,133 @@ class SAM2VideoPredictor(SAM2Base):
|
|||||||
inference_state, pred_masks
|
inference_state, pred_masks
|
||||||
)
|
)
|
||||||
yield frame_idx, obj_ids, video_res_masks
|
yield frame_idx, obj_ids, video_res_masks
|
||||||
|
@torch.inference_mode()
|
||||||
|
def add_new_frame(self, inference_state, new_image):
|
||||||
|
"""
|
||||||
|
Add a new frame to the inference state and cache its image features.
|
||||||
|
Args:
|
||||||
|
inference_state (dict): The current inference state containing cached frames, features, and tracking information.
|
||||||
|
new_image (Tensor or ndarray): The input image frame (in HWC or CHW format depending on upstream processing).
|
||||||
|
Returns:
|
||||||
|
frame_idx (int): The index of the newly added frame within the inference state.
|
||||||
|
"""
|
||||||
|
device = inference_state["device"]
|
||||||
|
|
||||||
|
# Preprocess the input frame and convert it to a normalized tensor
|
||||||
|
img_tensor, orig_h, orig_w = process_stream_frame(
|
||||||
|
img_array=new_image,
|
||||||
|
image_size=self.image_size,
|
||||||
|
offload_to_cpu=False,
|
||||||
|
compute_device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle initialization of the image sequence if this is the first frame
|
||||||
|
images = inference_state.get("images", None)
|
||||||
|
if images is None or (isinstance(images, list) and len(images) == 0):
|
||||||
|
# First frame: initialize image tensor batch
|
||||||
|
inference_state["images"] = img_tensor.unsqueeze(0) # Shape: [1, C, H, W]
|
||||||
|
else:
|
||||||
|
# Append to existing tensor batch
|
||||||
|
if isinstance(images, list):
|
||||||
|
raise ValueError(
|
||||||
|
"inference_state['images'] should be a Tensor, not a list after initialization."
|
||||||
|
)
|
||||||
|
|
||||||
|
img_tensor = img_tensor.to(images.device)
|
||||||
|
inference_state["images"] = torch.cat(
|
||||||
|
[images, img_tensor.unsqueeze(0)], dim=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update frame count and compute new frame index
|
||||||
|
inference_state["num_frames"] = inference_state["images"].shape[0]
|
||||||
|
frame_idx = inference_state["num_frames"] - 1
|
||||||
|
|
||||||
|
# Cache visual features for the newly added frame
|
||||||
|
image_batch = img_tensor.float().unsqueeze(0) # Shape: [1, C, H, W]
|
||||||
|
backbone_out = self.forward_image(image_batch)
|
||||||
|
inference_state["cached_features"][frame_idx] = (image_batch, backbone_out)
|
||||||
|
|
||||||
|
return frame_idx
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def infer_single_frame(self, inference_state, frame_idx):
|
||||||
|
"""
|
||||||
|
Run inference on a single frame using existing points/masks in the inference state.
|
||||||
|
Args:
|
||||||
|
inference_state (dict): The current state of the tracking process.
|
||||||
|
frame_idx (int): Index of the frame to run inference on.
|
||||||
|
Returns:
|
||||||
|
frame_idx (int): Same as input; the index of the processed frame.
|
||||||
|
obj_ids (list): List of currently tracked object IDs.
|
||||||
|
video_res_masks (Tensor): Segmentation masks predicted for the objects in the frame.
|
||||||
|
"""
|
||||||
|
if frame_idx >= inference_state["num_frames"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Frame index {frame_idx} out of range (num_frames={inference_state['num_frames']})."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.propagate_in_video_preflight(inference_state)
|
||||||
|
|
||||||
|
output_dict = inference_state["output_dict"]
|
||||||
|
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
||||||
|
batch_size = self._get_obj_num(inference_state)
|
||||||
|
|
||||||
|
# Ensure that initial conditioning points exist
|
||||||
|
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No conditioning points provided. Please add points before inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decide whether to clear nearby memory based on number of objects
|
||||||
|
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
||||||
|
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
||||||
|
)
|
||||||
|
|
||||||
|
obj_ids = inference_state["obj_ids"]
|
||||||
|
|
||||||
|
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||||
|
# If output is already consolidated with conditioning inputs
|
||||||
|
storage_key = "cond_frame_outputs"
|
||||||
|
current_out = output_dict[storage_key][frame_idx]
|
||||||
|
pred_masks = current_out["pred_masks"]
|
||||||
|
|
||||||
|
if clear_non_cond_mem:
|
||||||
|
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
||||||
|
|
||||||
|
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
||||||
|
# If output was inferred without conditioning
|
||||||
|
storage_key = "non_cond_frame_outputs"
|
||||||
|
current_out = output_dict[storage_key][frame_idx]
|
||||||
|
pred_masks = current_out["pred_masks"]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Run model inference for this frame
|
||||||
|
storage_key = "non_cond_frame_outputs"
|
||||||
|
current_out, pred_masks = self._run_single_frame_inference(
|
||||||
|
inference_state=inference_state,
|
||||||
|
output_dict=output_dict,
|
||||||
|
frame_idx=frame_idx,
|
||||||
|
batch_size=batch_size,
|
||||||
|
is_init_cond_frame=False,
|
||||||
|
point_inputs=None,
|
||||||
|
mask_inputs=None,
|
||||||
|
reverse=False,
|
||||||
|
run_mem_encoder=True,
|
||||||
|
)
|
||||||
|
output_dict[storage_key][frame_idx] = current_out
|
||||||
|
|
||||||
|
# Organize per-object outputs and mark frame as tracked
|
||||||
|
self._add_output_per_object(
|
||||||
|
inference_state, frame_idx, current_out, storage_key
|
||||||
|
)
|
||||||
|
inference_state["frames_already_tracked"][frame_idx] = {"reverse": False}
|
||||||
|
|
||||||
|
# Convert output to original video resolution
|
||||||
|
_, video_res_masks = self._get_orig_video_res_output(
|
||||||
|
inference_state, pred_masks
|
||||||
|
)
|
||||||
|
|
||||||
|
return frame_idx, obj_ids, video_res_masks
|
||||||
|
|
||||||
def _add_output_per_object(
|
def _add_output_per_object(
|
||||||
self, inference_state, frame_idx, current_out, storage_key
|
self, inference_state, frame_idx, current_out, storage_key
|
||||||
|
@@ -8,6 +8,7 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -209,6 +210,74 @@ def load_video_frames(
|
|||||||
"Only MP4 video and JPEG folder are supported at this moment"
|
"Only MP4 video and JPEG folder are supported at this moment"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_stream_frame(
|
||||||
|
img_array: np.ndarray,
|
||||||
|
image_size: int,
|
||||||
|
img_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
|
||||||
|
img_std: Tuple[float, float, float] = (0.229, 0.224, 0.225),
|
||||||
|
offload_to_cpu: bool = False,
|
||||||
|
compute_device: torch.device = torch.device("cuda"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Convert a raw image array (H,W,3 or 3,H,W) into a model‑ready tensor.
|
||||||
|
Steps
|
||||||
|
-----
|
||||||
|
1. Resize the shorter side to `image_size`, keeping aspect ratio,
|
||||||
|
then center‑crop/pad to `image_size` × `image_size`.
|
||||||
|
2. Change layout to [3, H, W] and cast to float32 in [0,1].
|
||||||
|
3. Normalise with ImageNet statistics.
|
||||||
|
4. Optionally move to `compute_device`.
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
img_tensor : torch.FloatTensor # shape [3, image_size, image_size]
|
||||||
|
orig_h : int
|
||||||
|
orig_w : int
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ↪ uses your existing helper so behaviour matches the batch loader
|
||||||
|
img_tensor, orig_h, orig_w = _resize_and_convert_to_tensor(img_array, image_size)
|
||||||
|
|
||||||
|
# Normalisation (done *after* potential device move for efficiency)
|
||||||
|
img_mean_t = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
|
||||||
|
img_std_t = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
|
||||||
|
|
||||||
|
if not offload_to_cpu:
|
||||||
|
img_tensor = img_tensor.to(compute_device)
|
||||||
|
img_mean_t = img_mean_t.to(compute_device)
|
||||||
|
img_std_t = img_std_t.to(compute_device)
|
||||||
|
|
||||||
|
img_tensor.sub_(img_mean_t).div_(img_std_t)
|
||||||
|
|
||||||
|
return img_tensor, orig_h, orig_w
|
||||||
|
|
||||||
|
|
||||||
|
def _resize_and_convert_to_tensor(img_array, image_size):
|
||||||
|
"""
|
||||||
|
Resize the input image array and convert it into a tensor.
|
||||||
|
Also return original image height and width.
|
||||||
|
"""
|
||||||
|
# Convert numpy array to PIL image and ensure RGB
|
||||||
|
img_pil = Image.fromarray(img_array).convert("RGB")
|
||||||
|
|
||||||
|
# Save original size (PIL: size = (width, height))
|
||||||
|
video_width, video_height = img_pil.size
|
||||||
|
|
||||||
|
# Resize with high-quality LANCZOS filter
|
||||||
|
img_resized = img_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
# Convert resized image back to numpy and then to float tensor
|
||||||
|
img_resized_array = np.array(img_resized)
|
||||||
|
|
||||||
|
if img_resized_array.dtype == np.uint8:
|
||||||
|
img_resized_array = img_resized_array / 255.0
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unexpected dtype: {img_resized_array.dtype}")
|
||||||
|
|
||||||
|
# Convert to PyTorch tensor and permute to [C, H, W]
|
||||||
|
img_tensor = torch.from_numpy(img_resized_array).permute(2, 0, 1)
|
||||||
|
|
||||||
|
return img_tensor, video_height, video_width
|
||||||
|
|
||||||
|
|
||||||
def load_video_frames_from_jpg_images(
|
def load_video_frames_from_jpg_images(
|
||||||
video_path,
|
video_path,
|
||||||
|
@@ -623,7 +623,7 @@ class Trainer:
|
|||||||
|
|
||||||
# compute output
|
# compute output
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.cuda.amp.autocast(
|
with torch.amp.autocast("cuda",
|
||||||
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
|
enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
|
||||||
dtype=(
|
dtype=(
|
||||||
get_amp_type(self.optim_conf.amp.amp_dtype)
|
get_amp_type(self.optim_conf.amp.amp_dtype)
|
||||||
@@ -858,7 +858,8 @@ class Trainer:
|
|||||||
# grads will also update a model even if the step doesn't produce
|
# grads will also update a model even if the step doesn't produce
|
||||||
# gradients
|
# gradients
|
||||||
self.optim.zero_grad(set_to_none=True)
|
self.optim.zero_grad(set_to_none=True)
|
||||||
with torch.cuda.amp.autocast(
|
with torch.amp.autocast(
|
||||||
|
"cuda",
|
||||||
enabled=self.optim_conf.amp.enabled,
|
enabled=self.optim_conf.amp.enabled,
|
||||||
dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
|
dtype=get_amp_type(self.optim_conf.amp.amp_dtype),
|
||||||
):
|
):
|
||||||
|
Reference in New Issue
Block a user