diff --git a/.asset/cat_dog.jpeg b/.asset/cat_dog.jpeg
new file mode 100644
index 0000000..8b30a3c
Binary files /dev/null and b/.asset/cat_dog.jpeg differ
diff --git a/README.md b/README.md
index 431dc56..7bf05af 100644
--- a/README.md
+++ b/README.md
@@ -151,13 +151,27 @@ nvidia-smi
 Replace `{GPU ID}`, `image_you_want_to_detect.jpg`, and `"dir you want to save the output"` with appropriate values in the following command
 ```bash
 CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
--c /GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
--p /GroundingDINO/weights/groundingdino_swint_ogc.pth \
+-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
+-p weights/groundingdino_swint_ogc.pth \
 -i image_you_want_to_detect.jpg \
 -o "dir you want to save the output" \
 -t "chair"
 [--cpu-only] # open it for cpu mode
 ```
+
+If you would like to specify the phrases to detect, here is a demo:
+```bash
+CUDA_VISIBLE_DEVICES={GPU ID} python demo/inference_on_a_image.py \
+-c groundingdino/config/GroundingDINO_SwinT_OGC.py \
+-p weights/groundingdino_swint_ogc.pth \
+-i .asset/cat_dog.jpeg \
+-o logs/1111 \
+-t "There is a cat and a dog in the image ." \
+--token_spans "[[[9, 10], [11, 14]], [[19, 20], [21, 24]]]"
+[--cpu-only] # open it for cpu mode
+```
+The `token_spans` argument specifies the character-level start and end positions of each phrase in the caption. For example, the first phrase is `[[9, 10], [11, 14]]`: `"There is a cat and a dog in the image ."[9:10] = 'a'` and `"There is a cat and a dog in the image ."[11:14] = 'cat'`, so it refers to the phrase `a cat`.
+
 See the `demo/inference_on_a_image.py` for more details.
 
 **Running with Python:**
diff --git a/demo/inference_on_a_image.py b/demo/inference_on_a_image.py
index 207227b..0dd332f 100644
--- a/demo/inference_on_a_image.py
+++ b/demo/inference_on_a_image.py
@@ -11,6 +11,7 @@ from groundingdino.models import build_model
 from groundingdino.util import box_ops
 from groundingdino.util.slconfig import SLConfig
 from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+from groundingdino.util.vl_utils import create_positive_map_from_span
 
 
 def plot_boxes_to_image(image_pil, tgt):
@@ -80,7 +81,10 @@ def load_model(model_config_path, model_checkpoint_path, cpu_only=False):
     return model
 
 
-def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, cpu_only=False):
+def get_grounding_output(model, image, caption, box_threshold, text_threshold=None, with_logits=True, cpu_only=False, token_spans=None):
+    assert text_threshold is not None or token_spans is not None, "text_threshold and token_spans should not both be None!"
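+    # When token_spans is provided, it takes precedence below: boxes are scored
+    # against the user-given character spans and text_threshold goes unused.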
     caption = caption.lower()
     caption = caption.strip()
     if not caption.endswith("."):
@@ -90,29 +92,56 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
     image = image.to(device)
     with torch.no_grad():
         outputs = model(image[None], captions=[caption])
-    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
-    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
-    logits.shape[0]
+    logits = outputs["pred_logits"].sigmoid()[0]  # (nq, 256)
+    boxes = outputs["pred_boxes"][0]  # (nq, 4)
 
     # filter output
-    logits_filt = logits.clone()
-    boxes_filt = boxes.clone()
-    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
-    logits_filt = logits_filt[filt_mask]  # num_filt, 256
-    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
-    logits_filt.shape[0]
+    if token_spans is None:
+        logits_filt = logits.cpu().clone()
+        boxes_filt = boxes.cpu().clone()
+        filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+        logits_filt = logits_filt[filt_mask]  # num_filt, 256
+        boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+
+        # get phrase
+        tokenlizer = model.tokenizer
+        tokenized = tokenlizer(caption)
+        # build pred
+        pred_phrases = []
+        for logit, box in zip(logits_filt, boxes_filt):
+            pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
+            if with_logits:
+                pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+            else:
+                pred_phrases.append(pred_phrase)
+    else:
+        # given-phrase mode: score each user-specified span against the token logits
+        positive_maps = create_positive_map_from_span(
+            model.tokenizer(caption),  # tokenize the same normalized caption the model saw
+            token_span=token_spans
+        ).to(image.device)  # n_phrase, 256
+
+        logits_for_phrases = positive_maps @ logits.T  # n_phrase, nq
+        all_logits = []
+        all_phrases = []
+        all_boxes = []
+        for (token_span, logit_phr) in zip(token_spans, logits_for_phrases):
+            # get phrase
+            phrase = ' '.join([caption[_s:_e] for (_s, _e) in token_span])
+            # get mask
+            filt_mask = logit_phr > box_threshold
+            # filt box
+            all_boxes.append(boxes[filt_mask])
+            # filt logits
+            all_logits.append(logit_phr[filt_mask])
+            if with_logits:
+                logit_phr_num = logit_phr[filt_mask]
+                all_phrases.extend([phrase + f"({str(logit.item())[:4]})" for logit in logit_phr_num])
+            else:
+                all_phrases.extend([phrase] * int(filt_mask.sum()))  # one label per kept box
+        boxes_filt = torch.cat(all_boxes, dim=0).cpu()
+        pred_phrases = all_phrases
 
-    # get phrase
-    tokenlizer = model.tokenizer
-    tokenized = tokenlizer(caption)
-    # build pred
-    pred_phrases = []
-    for logit, box in zip(logits_filt, boxes_filt):
-        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
-        if with_logits:
-            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
-        else:
-            pred_phrases.append(pred_phrase)
 
     return boxes_filt, pred_phrases
 
@@ -132,6 +161,12 @@ if __name__ == "__main__":
     parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
     parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
+    parser.add_argument("--token_spans", type=str, default=None, help=
+                        "The start and end positions of phrases of interest, as character offsets into the caption. \
+                        For example, given the caption 'a cat and a dog': \
+                        to detect 'cat', token_spans should be '[[[2, 5]], ]', since 'a cat and a dog'[2:5] is 'cat'; \
+                        to detect 'a cat', token_spans should be '[[[0, 1], [2, 5]], ]', since 'a cat and a dog'[0:1] is 'a' and 'a cat and a dog'[2:5] is 'cat'. \
\ + ") parser.add_argument("--cpu-only", action="store_true", help="running on cpu only!, default=False") args = parser.parse_args() @@ -144,6 +179,7 @@ if __name__ == "__main__": output_dir = args.output_dir box_threshold = args.box_threshold text_threshold = args.text_threshold + token_spans = args.token_spans # make dir os.makedirs(output_dir, exist_ok=True) @@ -155,9 +191,15 @@ if __name__ == "__main__": # visualize raw image image_pil.save(os.path.join(output_dir, "raw_image.jpg")) + # set the text_threshold to None if token_spans is set. + if token_spans is not None: + text_threshold = None + print("Using token_spans. Set the text_threshold to None.") + + # run model boxes_filt, pred_phrases = get_grounding_output( - model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only + model, image, text_prompt, box_threshold, text_threshold, cpu_only=args.cpu_only, token_spans=eval(token_spans) ) # visualize pred