support gsam2 image predictor model
This commit is contained in:
29
grounding_dino/groundingdino/util/get_tokenlizer.py
Normal file
29
grounding_dino/groundingdino/util/get_tokenlizer.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
|
||||
import os
|
||||
|
||||
def get_tokenlizer(text_encoder_type):
|
||||
if not isinstance(text_encoder_type, str):
|
||||
# print("text_encoder_type is not a str")
|
||||
if hasattr(text_encoder_type, "text_encoder_type"):
|
||||
text_encoder_type = text_encoder_type.text_encoder_type
|
||||
elif text_encoder_type.get("text_encoder_type", False):
|
||||
text_encoder_type = text_encoder_type.get("text_encoder_type")
|
||||
elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
|
||||
)
|
||||
print("final text_encoder_type: {}".format(text_encoder_type))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_pretrained_language_model(text_encoder_type):
|
||||
if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)):
|
||||
return BertModel.from_pretrained(text_encoder_type)
|
||||
if text_encoder_type == "roberta-base":
|
||||
return RobertaModel.from_pretrained(text_encoder_type)
|
||||
|
||||
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
|
Reference in New Issue
Block a user