1. Fix warnings.
2. Support CPU mode.
3. Update README.
README.md
@@ -1,9 +1,10 @@
 # Grounding DINO
+[📃Paper](https://arxiv.org/abs/2303.05499) |
+[📽️Video](https://www.youtube.com/watch?v=wxWDt5UiwY8) |
+[📯Demo on Colab](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) |
+[🤗Demo on HF (Coming soon)]()

 ---
 [](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) \
 [](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
 [](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
 [](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
@@ -20,8 +21,10 @@ Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499)
 - **High Performance.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
 - **Flexible.** Collaboration with Stable Diffusion for Image Editing.

-<!-- [](https://youtu.be/wxWDt5UiwY8)
-<iframe width="560" height="315" src="https://youtu.be/wxWDt5UiwY8" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen></iframe> -->
+## News
+[2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
+[2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. Thanks to @Piotr! \
+[2023/03/22] Code is available Now!

 <details open>
 <summary><font size="4">
@@ -30,15 +33,18 @@ Description
 <img src=".asset/hero_figure.png" alt="ODinW" width="100%">
 </details>

 ## TODO

 - [x] Release inference code and demo.
 - [x] Release checkpoints.
 - [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
+- [ ] Release training codes.

 ## Install

-If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
+If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. It will be compiled under CPU-only mode if no CUDA is available.

 ```bash
 pip install -e .
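Note: whether the extension builds with CUDA or in CPU-only mode is decided by the environment at install time. A minimal pre-install check (a sketch, not part of this commit):

```python
import os
import torch

# If CUDA_HOME points at a CUDA toolkit and a GPU build of PyTorch is present,
# `pip install -e .` compiles the custom ops with CUDA; otherwise it falls back
# to CPU-only compilation, as the README sentence above describes.
print("CUDA_HOME:", os.environ.get("CUDA_HOME"))
print("torch.cuda.is_available():", torch.cuda.is_available())
```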
@@ -46,15 +52,16 @@ pip install -e .

 ## Demo

-See the `demo/inference_on_a_image.py` for more details.
 ```bash
 CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
 -c /path/to/config \
 -p /path/to/checkpoint \
 -i .asset/cats.png \
 -o "outputs/0" \
--t "cat ear."
+-t "cat ear." \
+[--cpu-only] # open it for cpu mode
 ```
+See the `demo/inference_on_a_image.py` for more details.

 ## Checkpoints
@@ -68,6 +75,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
 <th>Data</th>
 <th>box AP on COCO</th>
 <th>Checkpoint</th>
+<th>Config</th>
 </tr>
 </thead>
 <tbody>
@@ -78,6 +86,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
 <td>O365,GoldG,Cap4M</td>
 <td>48.4 (zero-shot) / 57.2 (fine-tune)</td>
 <td><a href="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth">link</a></td>
+<td><a href="https://github.com/IDEA-Research/GroundingDINO/blob/main/groundingdino/config/GroundingDINO_SwinT_OGC.py">link</a></td>
 </tr>
 </tbody>
 </table>
demo/inference_on_a_image.py

@@ -39,7 +39,13 @@ def plot_boxes_to_image(image_pil, tgt):
         draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
         # draw.text((x0, y0), str(label), fill=color)

-        bbox = draw.textbbox((x0, y0), str(label))
+        font = ImageFont.load_default()
+        if hasattr(font, "getbbox"):
+            bbox = draw.textbbox((x0, y0), str(label), font)
+        else:
+            w, h = draw.textsize(str(label), font)
+            bbox = (x0, y0, w + x0, y0 + h)
+        # bbox = draw.textbbox((x0, y0), str(label))
         draw.rectangle(bbox, fill=color)
         draw.text((x0, y0), str(label), fill="white")
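The branch above exists because newer Pillow releases dropped `ImageDraw.textsize` in favor of `textbbox`/`getbbox`. A self-contained sketch of the same measurement fallback (the helper name `measure_label` is illustrative, not from the repo):

```python
from PIL import Image, ImageDraw, ImageFont

def measure_label(draw, xy, text, font):
    # Newer Pillow: fonts expose getbbox and ImageDraw has textbbox.
    if hasattr(font, "getbbox"):
        return draw.textbbox(xy, text, font)
    # Older Pillow: fall back to the since-removed textsize.
    w, h = draw.textsize(text, font)
    return (xy[0], xy[1], xy[0] + w, xy[1] + h)

image = Image.new("RGB", (200, 60), "black")
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()
box = measure_label(draw, (10, 10), "cat ear", font)
draw.rectangle(box, fill="red")
draw.text((10, 10), "cat ear", fill="white", font=font)
```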
@@ -63,9 +69,9 @@ def load_image(image_path):
     return image_pil, image


-def load_model(model_config_path, model_checkpoint_path):
+def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
     args = SLConfig.fromfile(model_config_path)
-    args.device = "cuda"
+    args.device = "cuda" if not cpu_only else "cpu"
     model = build_model(args)
     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
     load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -74,13 +80,14 @@ def load_model(model_config_path, model_checkpoint_path):
     return model


-def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
     caption = caption.lower()
     caption = caption.strip()
     if not caption.endswith("."):
         caption = caption + "."
-    model = model.cuda()
-    image = image.cuda()
+    device = "cuda" if not cpu_only else "cpu"
+    model = model.to(device)
+    image = image.to(device)
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
     logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
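With the `cpu_only` flag threaded through both helpers, the demo can also be driven from Python on a GPU-less machine, roughly as below (a sketch: the config and checkpoint paths are placeholders, and importing the demo script as a module assumes the repository root is on `PYTHONPATH`):

```python
# Sketch only: placeholder paths; run from the repository root so the import resolves.
from demo.inference_on_a_image import load_image, load_model, get_grounding_output

image_pil, image = load_image(".asset/cats.png")
model = load_model("/path/to/config", "/path/to/checkpoint", cpu_only=True)
boxes_filt, pred_phrases = get_grounding_output(
    model, image, "cat ear.", box_threshold=0.3, text_threshold=0.25, cpu_only=True
)
print(boxes_filt.shape, pred_phrases)
```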
@@ -125,6 +132,8 @@ if __name__ == "__main__":

     parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
     parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+
+    parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
     args = parser.parse_args()

     # cfg
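A small point on the flag spelling: argparse normalizes dashes in long option names to underscores, so `--cpu-only` is read back as `args.cpu_only` in the hunk below. A quick standalone check:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cpu-only", action="store_true")
print(parser.parse_args(["--cpu-only"]).cpu_only)  # True
```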
@@ -141,14 +150,14 @@ if __name__ == "__main__":
     # load image
     image_pil, image = load_image(image_path)
     # load model
-    model = load_model(config_file, checkpoint_path)
+    model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)

     # visualize raw image
     image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

     # run model
     boxes_filt, pred_phrases = get_grounding_output(
-        model, image, text_prompt, box_threshold, text_threshold
+        model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
     )

     # visualize pred
@@ -111,11 +111,11 @@ class PositionEmbeddingSineHW(nn.Module):
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

         dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
+        dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_tx

         dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
+        dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
         pos_y = y_embed[:, :, :, None] / dim_ty

         pos_x = torch.stack(
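The `//` to `torch.div(..., rounding_mode='floor')` changes here and in `gen_sineembed_for_position` below silence PyTorch's floor-division deprecation warning on tensors; the resulting values are unchanged. A minimal before/after sketch (illustrative sizes):

```python
import torch

dim = torch.arange(8, dtype=torch.float32)

# Old form: on the PyTorch versions this warning fix targets, `//` on tensors
# emitted a "__floordiv__ is deprecated" UserWarning.
old = 10000 ** (2 * (dim // 2) / 8)

# New form: identical values, no warning.
new = 10000 ** (2 * torch.div(dim, 2, rounding_mode="floor") / 8)

assert torch.equal(old, new)
```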
@@ -25,7 +25,10 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.init import constant_, xavier_uniform_

-from groundingdino import _C
+try:
+    from groundingdino import _C
+except:
+    warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")


 # helpers
@@ -323,6 +326,7 @@ class MultiScaleDeformableAttention(nn.Module):
                     reference_points.shape[-1]
                 )
             )
+
         if torch.cuda.is_available() and value.is_cuda:
             halffloat = False
             if value.dtype == torch.float16:
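Taken together, the guarded import above and the `torch.cuda.is_available() and value.is_cuda` check form the usual optional-extension pattern: use the compiled CUDA op when it is present and the tensors are on a GPU, otherwise take a pure-PyTorch path. A generic sketch of that pattern (module and function names are illustrative, not the repo's):

```python
import warnings

import torch

try:
    from my_package import _C  # compiled extension; absent on CPU-only installs
except ImportError:
    _C = None
    warnings.warn("Failed to load custom C++ ops. Running in CPU-only mode.")

def fused_or_fallback(x: torch.Tensor) -> torch.Tensor:
    # Use the CUDA kernel only when it was built and the tensor actually lives on a GPU.
    if _C is not None and torch.cuda.is_available() and x.is_cuda:
        return _C.fused_op(x)
    return x * 2.0  # pure-PyTorch fallback (placeholder computation)
```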
@@ -206,7 +206,7 @@ def gen_sineembed_for_position(pos_tensor):
     # sineembed_tensor = torch.zeros(n_query, bs, 256)
     scale = 2 * math.pi
     dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
-    dim_t = 10000 ** (2 * (dim_t // 2) / 128)
+    dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128)
     x_embed = pos_tensor[:, :, 0] * scale
     y_embed = pos_tensor[:, :, 1] * scale
     pos_x = x_embed[:, :, None] / dim_t
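For context, `dim_t` above is the standard Transformer sine-embedding divisor: even channels get sines, odd channels cosines, at geometrically spaced frequencies. A compact illustration of how one scaled coordinate becomes a 128-d embedding (illustrative only):

```python
import math

import torch

dim_t = torch.arange(128, dtype=torch.float32)
dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / 128)

x = torch.tensor([0.25]) * (2 * math.pi)   # one normalized coordinate, scaled as above
pos_x = x[:, None] / dim_t                 # shape (1, 128)
emb = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
print(emb.shape)  # torch.Size([1, 128])
```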
requirements.txt

@@ -6,4 +6,5 @@ yapf
 timm
 numpy
 opencv-python
 supervision==0.3.2
+pycocotools