@@ -1,5 +1,5 @@
|
|||||||
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
|
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
|
||||||
|
import os
|
||||||
|
|
||||||
def get_tokenlizer(text_encoder_type):
|
def get_tokenlizer(text_encoder_type):
|
||||||
if not isinstance(text_encoder_type, str):
|
if not isinstance(text_encoder_type, str):
|
||||||
@@ -8,6 +8,8 @@ def get_tokenlizer(text_encoder_type):
|
|||||||
text_encoder_type = text_encoder_type.text_encoder_type
|
text_encoder_type = text_encoder_type.text_encoder_type
|
||||||
elif text_encoder_type.get("text_encoder_type", False):
|
elif text_encoder_type.get("text_encoder_type", False):
|
||||||
text_encoder_type = text_encoder_type.get("text_encoder_type")
|
text_encoder_type = text_encoder_type.get("text_encoder_type")
|
||||||
|
elif os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type):
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
|
"Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
|
||||||
@@ -19,8 +21,9 @@ def get_tokenlizer(text_encoder_type):
|
|||||||
|
|
||||||
|
|
||||||
def get_pretrained_language_model(text_encoder_type):
|
def get_pretrained_language_model(text_encoder_type):
|
||||||
if text_encoder_type == "bert-base-uncased":
|
if text_encoder_type == "bert-base-uncased" or (os.path.isdir(text_encoder_type) and os.path.exists(text_encoder_type)):
|
||||||
return BertModel.from_pretrained(text_encoder_type)
|
return BertModel.from_pretrained(text_encoder_type)
|
||||||
if text_encoder_type == "roberta-base":
|
if text_encoder_type == "roberta-base":
|
||||||
return RobertaModel.from_pretrained(text_encoder_type)
|
return RobertaModel.from_pretrained(text_encoder_type)
|
||||||
|
|
||||||
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
|
raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
|
||||||
|
Reference in New Issue
Block a user