--- a/groundingdino/util/inference.py
+++ b/groundingdino/util/inference.py
@@ -6,6 +6,7 @@ import supervision as sv
 import torch
 from PIL import Image
 from torchvision.ops import box_convert
+import bisect
 
 import groundingdino.datasets.transforms as T
 from groundingdino.models import build_model
@@ -55,7 +56,8 @@ def predict(
         caption: str,
         box_threshold: float,
         text_threshold: float,
-        device: str = "cuda"
+        device: str = "cuda",
+        remove_combined: bool = False
 ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
     caption = preprocess_caption(caption=caption)
 
@@ -74,12 +76,23 @@ def predict(
 
     tokenizer = model.tokenizer
     tokenized = tokenizer(caption)
 
-    phrases = [
-        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
-        for logit
-        in logits
-    ]
+    if remove_combined:
+        sep_idx = [i for i in range(len(tokenized['input_ids'])) if tokenized['input_ids'][i] in [101, 102, 1012]]
+
+        phrases = []
+        for logit in logits:
+            max_idx = logit.argmax()
+            insert_idx = bisect.bisect_left(sep_idx, max_idx)
+            right_idx = sep_idx[insert_idx]
+            left_idx = sep_idx[insert_idx - 1]
+            phrases.append(get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer, left_idx, right_idx).replace('.', ''))
+    else:
+        phrases = [
+            get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
+            for logit
+            in logits
+        ]
 
     return boxes, logits.max(dim=1)[0], phrases
 
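A minimal usage sketch of the new flag (not part of the diff; the config, checkpoint, and image paths are hypothetical, and load_model/load_image are the existing helpers in this module). With the default remove_combined=False, a caption such as "a cat . a dog ." can yield a phrase that merges words from different sub-captions when a box's posmap crosses the "." separators; with remove_combined=True, each phrase is clipped to the single sub-caption around its strongest token.

    from groundingdino.util.inference import load_model, load_image, predict

    # Hypothetical paths, for illustration only.
    model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py",
                       "weights/groundingdino_swint_ogc.pth")
    image_source, image = load_image("image.jpg")

    # Old behaviour (kept as the default): phrases may span separator tokens.
    boxes, logits, phrases = predict(
        model=model, image=image, caption="a cat . a dog .",
        box_threshold=0.35, text_threshold=0.25,
    )

    # New behaviour: each phrase is restricted to one sub-caption.
    boxes, logits, phrases = predict(
        model=model, image=image, caption="a cat . a dog .",
        box_threshold=0.35, text_threshold=0.25, remove_combined=True,
    )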
--- a/groundingdino/util/utils.py
+++ b/groundingdino/util/utils.py
@@ -597,10 +597,12 @@ def targets_to(targets: List[Dict[str, Any]], device):
 
 
 def get_phrases_from_posmap(
-    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer
+    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: AutoTokenizer, left_idx: int = 0, right_idx: int = 255
 ):
     assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
     if posmap.dim() == 1:
+        posmap[0: left_idx + 1] = False
+        posmap[right_idx:] = False
         non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
         token_ids = [tokenized["input_ids"][i] for i in non_zero_idx]
         return tokenizer.decode(token_ids)
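For context on the clipping logic, a sketch under the assumption of a BERT-style tokenizer, where ids 101, 102, and 1012 are [CLS], [SEP], and ".": sep_idx collects the positions of all separator tokens, bisect_left locates the sub-caption containing the logit's strongest token, and the posmap is zeroed outside (left_idx, right_idx) before decoding. A toy walk-through with hypothetical token ids:

    import bisect
    import torch

    # "a cat . a dog ." tokenized BERT-style: [CLS] a cat . a dog . [SEP]
    input_ids = [101, 1037, 4937, 1012, 1037, 3899, 1012, 102]
    sep_idx = [i for i, t in enumerate(input_ids) if t in [101, 102, 1012]]  # [0, 3, 6, 7]

    # posmap fires on "cat" (index 2) but also, spuriously, on "dog" (index 5).
    posmap = torch.tensor([False, True, True, False, False, True, False, False])
    max_idx = 2  # strongest activation sits on "cat"

    insert_idx = bisect.bisect_left(sep_idx, max_idx)  # 1
    right_idx = sep_idx[insert_idx]                    # 3, the "." after "cat"
    left_idx = sep_idx[insert_idx - 1]                 # 0, the [CLS] token

    posmap[0: left_idx + 1] = False  # mask everything up to the left separator
    posmap[right_idx:] = False       # mask the separator and everything after it
    print(posmap.nonzero(as_tuple=True)[0].tolist())   # [1, 2] -> decodes to "a cat"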