fix distilbert in auto tokenizer
This commit is contained in:
@@ -94,13 +94,13 @@ class AutoTokenizer(object):
|
|||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
config = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache.
|
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache.
|
||||||
config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')`
|
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/bert_saved_model/')`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
if 'roberta' in pretrained_model_name_or_path:
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif 'bert' in pretrained_model_name_or_path:
|
||||||
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user