diff --git a/README.md b/README.md
index 349dd68..45634ef 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,10 @@
 # Grounding DINO
+[📃Paper](https://arxiv.org/abs/2303.05499) |
+[📽️Video](https://www.youtube.com/watch?v=wxWDt5UiwY8) |
+[📯Demo on Colab](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) |
+[🤗Demo on HF (Coming soon)]()
 
----
-
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb)
-
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) \
 [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-mscoco)](https://paperswithcode.com/sota/zero-shot-object-detection-on-mscoco?p=grounding-dino-marrying-dino-with-grounded) \
 [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/zero-shot-object-detection-on-odinw)](https://paperswithcode.com/sota/zero-shot-object-detection-on-odinw?p=grounding-dino-marrying-dino-with-grounded) \
 [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/grounding-dino-marrying-dino-with-grounded/object-detection-on-coco-minival)](https://paperswithcode.com/sota/object-detection-on-coco-minival?p=grounding-dino-marrying-dino-with-grounded) \
@@ -20,8 +21,10 @@ Official pytorch implementation of [Grounding DINO](https://arxiv.org/abs/2303.05499)
 - **High Performancce.** COCO zero-shot **52.5 AP** (training without COCO data!). COCO fine-tune **63.0 AP**.
 - **Flexible.** Collaboration with Stable Diffusion for Image Editting.
-
+## News
+[2023/03/27] Support CPU-only mode. Now the model can run on machines without GPUs.\
+[2023/03/25] A [demo](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/zero-shot-object-detection-with-grounding-dino.ipynb) for Grounding DINO is available at Colab. Thanks to @Piotr! \
+[2023/03/22] Code is available Now!
@@ -30,15 +33,18 @@ Description ODinW
+
+
 ## TODO
 - [x] Release inference code and demo.
 - [x] Release checkpoints.
 - [ ] Grounding DINO with Stable Diffusion and GLIGEN demos.
+- [ ] Release training codes.
 
 ## Install
 
-If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set.
+If you have a CUDA environment, please make sure the environment variable `CUDA_HOME` is set. The package will be compiled in CPU-only mode if CUDA is not available.
 
 ```bash
 pip install -e .
@@ -46,15 +52,16 @@ pip install -e .
 
 ## Demo
 
-See the `demo/inference_on_a_image.py` for more details.
 
 ```bash
 CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
 -c /path/to/config \
 -p /path/to/checkpoint \
 -i .asset/cats.png \
 -o "outputs/0" \
--t "cat ear."
+-t "cat ear." \
+[--cpu-only] # add this flag to run in CPU-only mode
 ```
+See the `demo/inference_on_a_image.py` for more details.
 
 ## Checkpoints
 
@@ -68,6 +75,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
       Data
       box AP on COCO
       Checkpoint
+      Config
@@ -78,6 +86,7 @@ CUDA_VISIBLE_DEVICES=6 python demo/inference_on_a_image.py \
       O365,GoldG,Cap4M
       48.4 (zero-shot) / 57.2 (fine-tune)
       link
+      link
diff --git a/demo/inference_on_a_image.py b/demo/inference_on_a_image.py
index 79406ae..8c91899 100644
--- a/demo/inference_on_a_image.py
+++ b/demo/inference_on_a_image.py
@@ -39,7 +39,13 @@ def plot_boxes_to_image(image_pil, tgt):
         draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
         # draw.text((x0, y0), str(label), fill=color)
 
-        bbox = draw.textbbox((x0, y0), str(label))
+        font = ImageFont.load_default()
+        if hasattr(font, "getbbox"):
+            bbox = draw.textbbox((x0, y0), str(label), font)
+        else:
+            w, h = draw.textsize(str(label), font)
+            bbox = (x0, y0, w + x0, y0 + h)
+        # bbox = draw.textbbox((x0, y0), str(label))
         draw.rectangle(bbox, fill=color)
         draw.text((x0, y0), str(label), fill="white")
 
@@ -63,9 +69,9 @@ def load_image(image_path):
     return image_pil, image
 
 
-def load_model(model_config_path, model_checkpoint_path):
+def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
     args = SLConfig.fromfile(model_config_path)
-    args.device = "cuda"
+    args.device = "cuda" if not cpu_only else "cpu"
     model = build_model(args)
     checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
     load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
@@ -74,13 +80,14 @@ def load_model(model_config_path, model_checkpoint_path):
     return model
 
 
-def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True):
+def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
     caption = caption.lower()
     caption = caption.strip()
     if not caption.endswith("."):
         caption = caption + "."
-    model = model.cuda()
-    image = image.cuda()
+    device = "cuda" if not cpu_only else "cpu"
+    model = model.to(device)
+    image = image.to(device)
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
     logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
@@ -125,6 +132,8 @@ if __name__ == "__main__":
     parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
     parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+
+    parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False")
     args = parser.parse_args()
 
     # cfg
@@ -141,14 +150,14 @@ if __name__ == "__main__":
     # load image
     image_pil, image = load_image(image_path)
     # load model
-    model = load_model(config_file, checkpoint_path)
+    model = load_model(config_file, checkpoint_path, cpu_only=args.cpu_only)
 
     # visualize raw image
     image_pil.save(os.path.join(output_dir, "raw_image.jpg"))
 
     # run model
     boxes_filt, pred_phrases = get_grounding_output(
-        model, image, text_prompt, box_threshold, text_threshold
+        model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only
     )
 
     # visualize pred
diff --git a/groundingdino/models/GroundingDINO/backbone/position_encoding.py b/groundingdino/models/GroundingDINO/backbone/position_encoding.py
index 14b429c..eac7e89 100644
--- a/groundingdino/models/GroundingDINO/backbone/position_encoding.py
+++ b/groundingdino/models/GroundingDINO/backbone/position_encoding.py
@@ -111,11 +111,11 @@ class PositionEmbeddingSineHW(nn.Module):
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
         dim_tx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_tx = self.temperatureW ** (2 * (dim_tx // 2) / self.num_pos_feats)
+        dim_tx = self.temperatureW ** (2 * (torch.div(dim_tx, 2, rounding_mode='floor')) / self.num_pos_feats)
         pos_x = x_embed[:, :, :, None] / dim_tx
 
         dim_ty = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
-        dim_ty = self.temperatureH ** (2 * (dim_ty // 2) / self.num_pos_feats)
+        dim_ty = self.temperatureH ** (2 * (torch.div(dim_ty, 2, rounding_mode='floor')) / self.num_pos_feats)
         pos_y = y_embed[:, :, :, None] / dim_ty
 
         pos_x = torch.stack(
diff --git a/groundingdino/models/GroundingDINO/ms_deform_attn.py b/groundingdino/models/GroundingDINO/ms_deform_attn.py
index a51d7d2..489d501 100644
--- a/groundingdino/models/GroundingDINO/ms_deform_attn.py
+++ b/groundingdino/models/GroundingDINO/ms_deform_attn.py
@@ -25,7 +25,10 @@ from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 from torch.nn.init import constant_, xavier_uniform_
 
-from groundingdino import _C
+try:
+    from groundingdino import _C
+except:
+    warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only!")
Running on CPU mode Only!") # helpers @@ -323,6 +326,7 @@ class MultiScaleDeformableAttention(nn.Module): reference_points.shape[-1] ) ) + if torch.cuda.is_available() and value.is_cuda: halffloat = False if value.dtype == torch.float16: diff --git a/groundingdino/models/GroundingDINO/utils.py b/groundingdino/models/GroundingDINO/utils.py index caf0f1b..5bd18f7 100644 --- a/groundingdino/models/GroundingDINO/utils.py +++ b/groundingdino/models/GroundingDINO/utils.py @@ -206,7 +206,7 @@ def gen_sineembed_for_position(pos_tensor): # sineembed_tensor = torch.zeros(n_query, bs, 256) scale = 2 * math.pi dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) - dim_t = 10000 ** (2 * (dim_t // 2) / 128) + dim_t = 10000 ** (2 * (torch.div(dim_t, 2, rounding_mode='floor')) / 128) x_embed = pos_tensor[:, :, 0] * scale y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t diff --git a/requirements.txt b/requirements.txt index 63773b7..f52ed0a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ yapf timm numpy opencv-python -supervision==0.3.2 \ No newline at end of file +supervision==0.3.2 +pycocotools \ No newline at end of file