diff --git a/LICENSE b/LICENSE
index b1395e9..f1460f5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright 2020 - present, Facebook, Inc
+   Copyright 2023 - present, IDEA Research.

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
diff --git a/README.md b/README.md
index 50e516a..8704634 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,7 @@ PyTorch implementation and pretrained models for Grounding DINO. For details, se
 ## :fire: News
+- **`2023/06/17`**: We provide an example to evaluate Grounding DINO's zero-shot performance on COCO.
 - **`2023/04/15`**: Refer to [CV in the Wild Readings](https://github.com/Computer-Vision-in-the-Wild/CVinW_Readings) for those who are interested in open-set recognition!
 - **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
 - **`2023/04/08`**: We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.
@@ -129,24 +130,16 @@ cd GroundingDINO/
 Install the required dependencies in the current directory.

 ```bash
-pip3 install -q -e .
+pip install -e .
 ```

-Create a new directory called "weights" to store the model weights.
+
+Download pre-trained model weights.

 ```bash
 mkdir weights
-```
-
-Change the current directory to the "weights" folder.
-
-```bash
 cd weights
-```
-
-Download the model weights file.
-
-```bash
 wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
+cd ..
 ```

 ## :arrow_forward: Demo
@@ -201,6 +194,19 @@ We also provide a demo code to integrate Grounding DINO with Gradio Web UI. See
 - We release [demos](demo/image_editing_with_groundingdino_gligen.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [GLIGEN](https://github.com/gligen/GLIGEN) for more controllable image editings.
 - We release [demos](demo/image_editing_with_groundingdino_stablediffusion.ipynb) to combine [Grounding DINO](https://arxiv.org/abs/2303.05499) with [Stable Diffusion](https://github.com/Stability-AI/StableDiffusion) for image editings.

+## COCO Zero-shot Evaluations
+
+We provide an example to evaluate Grounding DINO's zero-shot performance on COCO. The resulting mAP should be **48.5**.
+
+```bash
+CUDA_VISIBLE_DEVICES=0 \
+python demo/test_ap_on_coco.py \
+ -c groundingdino/config/GroundingDINO_SwinT_OGC.py \
+ -p weights/groundingdino_swint_ogc.pth \
+ --anno_path /path/to/coco/annotations/instances_val2017.json \
+ --image_dir /path/to/coco/val2017
+```
+
 ## :luggage: Checkpoints
diff --git a/demo/test_ap_on_coco.py b/demo/test_ap_on_coco.py
new file mode 100644
index 0000000..59ce6a2
--- /dev/null
+++ b/demo/test_ap_on_coco.py
@@ -0,0 +1,233 @@
+import argparse
+import os
+import sys
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, DistributedSampler
+
+from groundingdino.models import build_model
+import groundingdino.datasets.transforms as T
+from groundingdino.util import box_ops, get_tokenlizer
+from groundingdino.util.misc import clean_state_dict, collate_fn
+from groundingdino.util.slconfig import SLConfig
+
+# from torchvision.datasets import CocoDetection
+import torchvision
+
+from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span
+from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator
+
+
+def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
+    args = SLConfig.fromfile(model_config_path)
+    args.device = device
+    model = build_model(args)
+    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
+    model.load_state_dict(clean_state_dict(checkpoint["ema_model"]), strict=False)
+    model.eval()
+    return model
+
+
+class CocoDetection(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms):
+        super().__init__(img_folder, ann_file)
+        self._transforms = transforms
+
+    def __getitem__(self, idx):
+        img, target = super().__getitem__(idx)  # target: list
+
+        # import ipdb; ipdb.set_trace()
+
+        w, h = img.size
+        boxes = [obj["bbox"] for obj in target]
+        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
+        boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
+        boxes[:, 0::2].clamp_(min=0, max=w)
+        boxes[:, 1::2].clamp_(min=0, max=h)
+        # filter out invalid boxes
+        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
+        boxes = boxes[keep]
+
+        target_new = {}
+        image_id = self.ids[idx]
+        target_new["image_id"] = image_id
+        target_new["boxes"] = boxes
+        target_new["orig_size"] = torch.as_tensor([int(h), int(w)])
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target_new)
+
+        return img, target
+
+
+class PostProcessCocoGrounding(nn.Module):
+    """ This module converts the model's output into the format expected by the coco api"""
+
+    def __init__(self, num_select=300, coco_api=None, tokenlizer=None) -> None:
+        super().__init__()
+        self.num_select = num_select
+
+        assert coco_api is not None
+        category_dict = coco_api.dataset['categories']
+        cat_list = [item['name'] for item in category_dict]
+        captions, cat2tokenspan = build_captions_and_token_span(cat_list, True)
+        tokenspanlist = [cat2tokenspan[cat] for cat in cat_list]
+        positive_map = create_positive_map_from_span(
+            tokenlizer(captions), tokenspanlist)  # (80, 256), normalized
+
+        id_map = {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21, 20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34, 30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44, 40: 46,
+                  41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55, 50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65, 60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79, 70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90}
+
+        # remap the 80 contiguous labels to the 91 COCO category ids
+        new_pos_map = torch.zeros((91, 256))
+        for k, v in id_map.items():
+            new_pos_map[v] = positive_map[k]
+        self.positive_map = new_pos_map
+
+    @torch.no_grad()
+    def forward(self, outputs, target_sizes, not_to_xyxy=False):
+        """ Perform the computation
+        Parameters:
+            outputs: raw outputs of the model
+            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
+                          For evaluation, this must be the original image size (before any data augmentation)
+                          For visualization, this should be the image size after data augmentation, but before padding
+        """
+        num_select = self.num_select
+        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
+
+        # map token-level logits to category-level scores
+        prob_to_token = out_logits.sigmoid()  # (bs, nq, 256)
+        pos_maps = self.positive_map.to(prob_to_token.device)
+        # (bs, nq, 256) @ (91, 256).T -> (bs, nq, 91)
+        prob_to_label = prob_to_token @ pos_maps.T
+
+        # if os.environ.get('IPDB_SHILONG_DEBUG', None) == 'INFO':
+        #     import ipdb; ipdb.set_trace()
+
+        assert len(out_logits) == len(target_sizes)
+        assert target_sizes.shape[1] == 2
+
+        prob = prob_to_label
+        topk_values, topk_indexes = torch.topk(
+            prob.view(out_logits.shape[0], -1), num_select, dim=1)
+        scores = topk_values
+        topk_boxes = topk_indexes // prob.shape[2]
+        labels = topk_indexes % prob.shape[2]
+
+        if not_to_xyxy:
+            boxes = out_bbox
+        else:
+            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
+
+        boxes = torch.gather(
+            boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
+
+        # and from relative [0, 1] to absolute [0, height] coordinates
+        img_h, img_w = target_sizes.unbind(1)
+        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
+        boxes = boxes * scale_fct[:, None, :]
+
+        results = [{'scores': s, 'labels': l, 'boxes': b}
+                   for s, l, b in zip(scores, labels, boxes)]
+
+        return results
+
+
+def main(args):
+    # config
+    cfg = SLConfig.fromfile(args.config_file)
+
+    # build model
+    model = load_model(args.config_file, args.checkpoint_path)
+    model = model.to(args.device)
+    model = model.eval()
+
+    # build dataloader
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    dataset = CocoDetection(
+        args.image_dir, args.anno_path, transforms=transform)
+    data_loader = DataLoader(
+        dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, collate_fn=collate_fn)
+
+    # build post processor
+    tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
+    postprocessor = PostProcessCocoGrounding(
+        num_select=args.num_select, coco_api=dataset.coco, tokenlizer=tokenlizer)
+
+    # build evaluator
+    evaluator = CocoGroundingEvaluator(
+        dataset.coco, iou_types=("bbox",), useCats=True)
+
+    # build captions: concatenate all category names into a single text prompt
+    category_dict = dataset.coco.dataset['categories']
+    cat_list = [item['name'] for item in category_dict]
+    caption = " . ".join(cat_list) + ' .'
+    print("Input text prompt:", caption)
+
+    # run inference
+    start = time.time()
+    for i, (images, targets) in enumerate(data_loader):
+        # get images and captions
+        images = images.tensors.to(args.device)
+        bs = images.shape[0]
+        input_captions = [caption] * bs
+
+        # feed to the model
+        outputs = model(images, captions=input_captions)
+
+        orig_target_sizes = torch.stack(
+            [t["orig_size"] for t in targets], dim=0).to(images.device)
+        results = postprocessor(outputs, orig_target_sizes)
+        cocogrounding_res = {
+            target["image_id"]: output for target, output in zip(targets, results)}
+        evaluator.update(cocogrounding_res)
+
+        if (i+1) % 30 == 0:
+            used_time = time.time() - start
+            eta = len(data_loader) / (i+1) * used_time - used_time
+            print(
+                f"processed {i+1}/{len(data_loader)} images. time: {used_time:.2f}s, ETA: {eta:.2f}s")
+
+    evaluator.synchronize_between_processes()
+    evaluator.accumulate()
+    evaluator.summarize()
+
+    print("Final results:", evaluator.coco_eval["bbox"].stats.tolist())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        "Grounding DINO eval on COCO", add_help=True)
+    # load model
+    parser.add_argument("--config_file", "-c", type=str,
+                        required=True, help="path to config file")
+    parser.add_argument(
+        "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
+    )
+    parser.add_argument("--device", type=str, default="cuda",
+                        help="running device (default: cuda)")
+
+    # post processing
+    parser.add_argument("--num_select", type=int, default=300,
+                        help="number of top-k predictions to keep")
+
+    # coco info
+    parser.add_argument("--anno_path", type=str,
+                        required=True, help="path to the COCO annotation file, e.g. instances_val2017.json")
+    parser.add_argument("--image_dir", type=str,
+                        required=True, help="path to the COCO image directory, e.g. val2017")
+    parser.add_argument("--num_workers", type=int, default=4,
+                        help="number of workers for dataloader")
+    args = parser.parse_args()
+
+    main(args)
diff --git a/groundingdino/datasets/cocogrounding_eval.py b/groundingdino/datasets/cocogrounding_eval.py
new file mode 100644
index 0000000..7693a18
--- /dev/null
+++ b/groundingdino/datasets/cocogrounding_eval.py
@@ -0,0 +1,269 @@
+# ------------------------------------------------------------------------
+# Grounding DINO. Modified by Shilong Liu.
+# url: https://github.com/IDEA-Research/GroundingDINO
+# Copyright (c) 2023 IDEA. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+"""
+COCO evaluator that works in distributed mode.
+ +Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py +The difference is that there is less copy-pasting from pycocotools +in the end of the file, as python3 can suppress prints with contextlib +""" +import contextlib +import copy +import os + +import numpy as np +import pycocotools.mask as mask_util +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval + +from groundingdino.util.misc import all_gather + + +class CocoGroundingEvaluator(object): + def __init__(self, coco_gt, iou_types, useCats=True): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + self.coco_eval[iou_type].useCats = useCats + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + self.useCats = useCats + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + + # suppress pycocotools prints + with open(os.devnull, "w") as devnull: + with contextlib.redirect_stdout(devnull): + coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() + + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + coco_eval.params.useCats = self.useCats + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type {}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + 
coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "keypoints": keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = all_gather(img_ids) + all_eval_imgs = all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + + +def evaluate(self): + """ + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + """ + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + print("useSegm (deprecated) is not None. 
Running {} evaluation".format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/groundingdino/models/GroundingDINO/groundingdino.py b/groundingdino/models/GroundingDINO/groundingdino.py index 052df62..a5758fd 100644 --- a/groundingdino/models/GroundingDINO/groundingdino.py +++ b/groundingdino/models/GroundingDINO/groundingdino.py @@ -228,7 +228,6 @@ class GroundingDINO(nn.Module): captions = kw["captions"] else: captions = [t["caption"] for t in targets] - len(captions) # encoder texts tokenized = self.tokenizer(captions, padding="longest", return_tensors="pt").to( diff --git a/groundingdino/version.py b/groundingdino/version.py deleted file mode 100644 index 3dc1f76..0000000 --- a/groundingdino/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.1.0"