@@ -71,7 +71,7 @@ def predict(
|
||||
tokenized = tokenizer(caption)
|
||||
|
||||
phrases = [
|
||||
get_phrases_from_posmap(logit > text_threshold, tokenized, caption).replace('.', '')
|
||||
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
|
||||
for logit
|
||||
in logits
|
||||
]
|
||||
|
@@ -7,6 +7,7 @@ from typing import Any, Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from groundingdino.util.slconfig import SLConfig
|
||||
|
||||
@@ -595,27 +596,13 @@ def targets_to(targets: List[Dict[str, Any]], device):
|
||||
]
|
||||
|
||||
|
||||
def get_phrases_from_posmap(
    posmap: torch.BoolTensor, tokenized: Dict, tokenizer: "AutoTokenizer"
):
    """Decode the tokens selected by a 1-D position mask into a phrase.

    The positions where ``posmap`` is non-zero are used as indices into
    ``tokenized["input_ids"]``; the selected ids are decoded in one call so
    that sub-word pieces are merged by the tokenizer itself rather than by
    splitting the caption on whitespace.

    Args:
        posmap: 1-dim tensor; every non-zero entry marks a selected token
            position (callers pass a boolean mask such as
            ``logit > text_threshold``).
        tokenized: Mapping whose ``"input_ids"`` entry is indexable by token
            position (presumably the tokenizer's output for a single
            caption -- confirm against the caller).
        tokenizer: Object exposing ``decode(token_ids)``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the selected ids
        (the caller treats it as a ``str``).

    Raises:
        AssertionError: If ``posmap`` is not a ``torch.Tensor``.
        NotImplementedError: If ``posmap`` is not 1-dimensional.
    """
    assert isinstance(posmap, torch.Tensor), "posmap must be torch.Tensor"
    if posmap.dim() != 1:
        raise NotImplementedError("posmap must be 1-dim")

    # Indices of the selected token positions, in ascending order.
    non_zero_idx = posmap.nonzero(as_tuple=True)[0].tolist()
    input_ids = tokenized["input_ids"]
    token_ids = [input_ids[i] for i in non_zero_idx]
    return tokenizer.decode(token_ids)
|
||||
|
Reference in New Issue
Block a user