diff --git a/README.md b/README.md
index 349dd68..45634ef 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,10 @@
# Grounding DINO
+[📃Paper](https://arxiv.org/abs/2303.05499) |
+[📽️Video](https://www.youtube.com/watch?v=wxWDt5UiwY8) |
+[📯Demo on Colab](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) |
+[🤗Demo on HF (Coming soon)]()
----
-
-[](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
-
+[](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) \
[](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
[](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
[](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
@@ -20,8 +21,10 @@ Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.0
- **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
- **Flexible.** Collaboration with Stable Diffusion for Image Editting.
-
+## News
+[2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
+[2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. Thanks to @Piotr! \
+[2023/03/22] Code is available Now!
@@ -30,15 +33,18 @@ Description
+
+
## TODO
- [x] Release inference code and demo.
- [x] Release checkpoints.
- [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
+- [ ] Release training codes.
## Install
-If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
+If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA available.
```bash
pip install -e .
@@ -46,15 +52,16 @@ pip install -e .
## Demo
-See the `demo/inference_on_a_image.py` for more details.
```bash
CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
-c /path/to/config \
-p /path/to/checkpoint \
-i .asset/cats.png \
-o "outputs/0" \
- -t "cat ear."
+ -t "cat ear." \
+ [--cpu-only] # open it for cpu mode
```
+See the `demo/inference_on_a_image.py` for more details.
## Checkpoints
@@ -68,6 +75,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
Data |
box AP on COCO |
Checkpoint |
+ Config |
@@ -78,6 +86,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
O365,GoldG,Cap4M |
48.4 (zero-shot) / 57.2 (fine-tune) |
link |
+ link |
diff --git a/demo/inference_on_a_image.py b/demo/inference_on_a_image.py
index 79406ae..8c91899 100644
--- a/demo/inference_on_a_image.py
+++ b/demo/inference_on_a_image.py
@@ -39,7 +39,13 @@ def plot_boxes_to_image(image_pil, tgt):
draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
# draw.text((x0, y0), str(label), fill=color)
- bbox = draw.textbbox((x0, y0), str(label))
+ font = ImageFont.load_default()
+ if hasattr(font, "getbbox"):
+ bbox = draw.textbbox((x0, y0), str(label), font)
+ else:
+ w, h = draw.textsize(str(label), font)
+ bbox = (x0, y0, w + x0, y0 + h)
+ # bbox = draw.textbbox((x0, y0), str(label))
draw.rectangle(bbox, fill=color)
draw.text((x0, y0), str(label), fill="white")
@@ -63,9 +69,9 @@ def load_image(image_path):
return image_pil, image
-def load_model(model_config_path, model_checkpoint_path):
+def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
args = SLConfig.fromfile(model_config_path)
- args.device = "cuda"
+ args.device = "cuda" if not cpu_only else "cpu"
model = build_model(args)
checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -74,13 +80,14 @@ def load_model(model_config_path, model_checkpoint_path):
return model
-def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
caption = caption.lower()
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
- model = model.cuda()
- image = image.cuda()
+ device = "cuda" if not cpu_only else "cpu"
+ model = model.to(device)
+ image = image.to(device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
@@ -125,6 +132,8 @@ if __name__ == "__main__":
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+
+ parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
args = parser.parse_args()
# cfg
@@ -141,14 +150,14 @@ if __name__ == "__main__":
# load image
image_pil, image = load_image(image_path)
# load model
- model = load_model(config_file, checkpoint_path)
+ model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
# visualize raw image
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
# run model
boxes_filt, pred_phrases = get_grounding_output(
- model, image, text_prompt, box_threshold, text_threshold
+ model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
)
# visualize pred
diff --git a/groundingdino/models/GroundingDINO/backbone/position_encoding.py b/groundingdino/models/GroundingDINO/backbone/position_encoding.py
index 14b429c..eac7e89 100644
--- a/groundingdino/models/GroundingDINO/backbone/position_encoding.py
+++ b/groundingdino/models/GroundingDINO/backbone/position_encoding.py
@@ -111,11 +111,11 @@ class PositionEmbeddingSineHW(nn.Module):
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
+ dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_tx
dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
- dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
+ dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
pos_y = y_embed[:, :, :, None] / dim_ty
pos_x = torch.stack(
diff --git a/groundingdino/models/GroundingDINO/ms_deform_attn.py b/groundingdino/models/GroundingDINO/ms_deform_attn.py
index a51d7d2..489d501 100644
--- a/groundingdino/models/GroundingDINO/ms_deform_attn.py
+++ b/groundingdino/models/GroundingDINO/ms_deform_attn.py
@@ -25,7 +25,10 @@ from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.init import constant_, xavier_uniform_
-from groundingdino import _C
+try:
+ from groundingdino import _C
+except:
+ warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
# helpers
@@ -323,6 +326,7 @@ class MultiScaleDeformableAttention(nn.Module):
reference_points.shape[-1]
)
)
+
if torch.cuda.is_available() and value.is_cuda:
halffloat = False
if value.dtype == torch.float16:
diff --git a/groundingdino/models/GroundingDINO/utils.py b/groundingdino/models/GroundingDINO/utils.py
index caf0f1b..5bd18f7 100644
--- a/groundingdino/models/GroundingDINO/utils.py
+++ b/groundingdino/models/GroundingDINO/utils.py
@@ -206,7 +206,7 @@ def gen_sineembed_for_position(pos_tensor):
# sineembed_tensor = torch.zeros(n_query, bs, 256)
scale = 2 * math.pi
dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
- dim_t = 10000 ** (2 * (dim_t // 2) / 128)
+ dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
x_embed = pos_tensor[:, :, 0] * scale
y_embed = pos_tensor[:, :, 1] * scale
pos_x = x_embed[:, :, None] / dim_t
diff --git a/requirements.txt b/requirements.txt
index 63773b7..f52ed0a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,5 @@ yapf
timm
numpy
opencv-python
-supervision==0.3.2
\ No newline at end of file
+supervision==0.3.2
+pycocotools
\ No newline at end of file