diff --git a/groundingdino/util/get_tokenlizer.py b/groundingdino/util/get_tokenlizer.py index f7dcf7e..dd2d972 100644 --- a/groundingdino/util/get_tokenlizer.py +++ b/groundingdino/util/get_tokenlizer.py @@ -1,5 +1,5 @@ from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast - +import os def get_tokenlizer(text_encoder_type): if not isinstance(text_encoder_type, str): @@ -8,6 +8,8 @@ def get_tokenlizer(text_encoder_type): text_encoder_type = text_encoder_type.text_encoder_type elif text_encoder_type.get("text_encoder_type", False): text_encoder_type = text_encoder_type.get("text_encoder_type") + elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type): + pass else: raise ValueError( "Unknown type of text_encoder_type: {}".format(type(text_encoder_type)) @@ -19,8 +21,9 @@ def get_tokenlizer(text_encoder_type): def get_pretrained_language_model(text_encoder_type): - if text_encoder_type == "bert-base-uncased": + if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)): return BertModel.from_pretrained(text_encoder_type) if text_encoder_type == "roberta-base": return RobertaModel.from_pretrained(text_encoder_type) + raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))