diff --git a/groundingdino/util/inference.py b/groundingdino/util/inference.py index fe3ac64..718bc7b 100644 --- a/groundingdino/util/inference.py +++ b/groundingdino/util/inference.py @@ -6,6 +6,7 @@ import supervision as sv import torch from PIL import Image from torchvision.ops import box_convert +import bisect import groundingdino.datasets.transforms as T from groundingdino.models import build_model @@ -55,7 +56,8 @@ def predict( caption: str, box_threshold: float, text_threshold: float, - device: str = "cuda" + device: str = "cuda", + remove_combined: bool = False ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: caption = preprocess_caption(caption=caption) @@ -74,12 +76,23 @@ def predict( tokenizer = model.tokenizer tokenized = tokenizer(caption) - - phrases = [ - get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') - for logit - in logits - ] + + if remove_combined: + sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]] + + phrases = [] + for logit in logits: + max_idx = logit.argmax() + insert_idx = bisect.bisect_left(sep_idx, max_idx) + right_idx = sep_idx[insert_idx] + left_idx = sep_idx[insert_idx - 1] + phrases.append(get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer, left_idx, right_idx).replace('.', '')) + else: + phrases = [ + get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') + for logit + in logits + ] return boxes, logits.max(dim=1)[0], phrases diff --git a/groundingdino/util/utils.py b/groundingdino/util/utils.py index e9f0318..8cf83ae 100644 --- a/groundingdino/util/utils.py +++ b/groundingdino/util/utils.py @@ -597,10 +597,12 @@ def targets_to(targets: List[Dict[str, Any]], device): def get_phrases_from_posmap( - posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer + posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer, left_idx: int = 0, right_idx: int = 255 ): assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor" if posmap.dim() == 1: + posmap[0: left_idx + 1] = False + posmap[right_idx:] = False non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist() token_ids = [tokenized["input_ids"][i] for i in non_zero_idx] return tokenizer.decode(token_ids)