support gsam2 image predictor model
This commit is contained in:
100
grounding_dino/groundingdino/util/vl_utils.py
Normal file
100
grounding_dino/groundingdino/util/vl_utils.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def create_positive_map_from_span(tokenized, token_span, max_text_len=256):
|
||||
"""construct a map such that positive_map[i,j] = True iff box i is associated to token j
|
||||
Input:
|
||||
- tokenized:
|
||||
- input_ids: Tensor[1, ntokens]
|
||||
- attention_mask: Tensor[1, ntokens]
|
||||
- token_span: list with length num_boxes.
|
||||
- each item: [start_idx, end_idx]
|
||||
"""
|
||||
positive_map = torch.zeros((len(token_span), max_text_len), dtype=torch.float)
|
||||
for j, tok_list in enumerate(token_span):
|
||||
for (beg, end) in tok_list:
|
||||
beg_pos = tokenized.char_to_token(beg)
|
||||
end_pos = tokenized.char_to_token(end - 1)
|
||||
if beg_pos is None:
|
||||
try:
|
||||
beg_pos = tokenized.char_to_token(beg + 1)
|
||||
if beg_pos is None:
|
||||
beg_pos = tokenized.char_to_token(beg + 2)
|
||||
except:
|
||||
beg_pos = None
|
||||
if end_pos is None:
|
||||
try:
|
||||
end_pos = tokenized.char_to_token(end - 2)
|
||||
if end_pos is None:
|
||||
end_pos = tokenized.char_to_token(end - 3)
|
||||
except:
|
||||
end_pos = None
|
||||
if beg_pos is None or end_pos is None:
|
||||
continue
|
||||
|
||||
assert beg_pos is not None and end_pos is not None
|
||||
if os.environ.get("SHILONG_DEBUG_ONLY_ONE_POS", None) == "TRUE":
|
||||
positive_map[j, beg_pos] = 1
|
||||
break
|
||||
else:
|
||||
positive_map[j, beg_pos : end_pos + 1].fill_(1)
|
||||
|
||||
return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
|
||||
|
||||
|
||||
def build_captions_and_token_span(cat_list, force_lowercase):
|
||||
"""
|
||||
Return:
|
||||
captions: str
|
||||
cat2tokenspan: dict
|
||||
{
|
||||
'dog': [[0, 2]],
|
||||
...
|
||||
}
|
||||
"""
|
||||
|
||||
cat2tokenspan = {}
|
||||
captions = ""
|
||||
for catname in cat_list:
|
||||
class_name = catname
|
||||
if force_lowercase:
|
||||
class_name = class_name.lower()
|
||||
if "/" in class_name:
|
||||
class_name_list: List = class_name.strip().split("/")
|
||||
class_name_list.append(class_name)
|
||||
class_name: str = random.choice(class_name_list)
|
||||
|
||||
tokens_positive_i = []
|
||||
subnamelist = [i.strip() for i in class_name.strip().split(" ")]
|
||||
for subname in subnamelist:
|
||||
if len(subname) == 0:
|
||||
continue
|
||||
if len(captions) > 0:
|
||||
captions = captions + " "
|
||||
strat_idx = len(captions)
|
||||
end_idx = strat_idx + len(subname)
|
||||
tokens_positive_i.append([strat_idx, end_idx])
|
||||
captions = captions + subname
|
||||
|
||||
if len(tokens_positive_i) > 0:
|
||||
captions = captions + " ."
|
||||
cat2tokenspan[class_name] = tokens_positive_i
|
||||
|
||||
return captions, cat2tokenspan
|
||||
|
||||
|
||||
def build_id2posspan_and_caption(category_dict: dict):
|
||||
"""Build id2pos_span and caption from category_dict
|
||||
|
||||
Args:
|
||||
category_dict (dict): category_dict
|
||||
"""
|
||||
cat_list = [item["name"].lower() for item in category_dict]
|
||||
id2catname = {item["id"]: item["name"].lower() for item in category_dict}
|
||||
caption, cat2posspan = build_captions_and_token_span(cat_list, force_lowercase=True)
|
||||
id2posspan = {catid: cat2posspan[catname] for catid, catname in id2catname.items()}
|
||||
return id2posspan, caption
|
Reference in New Issue
Block a user