Add ALBERT to AutoClasses
This commit is contained in:
committed by
Lysandre Debut
parent
4a666885b5
commit
ecf15ebf3b
@@ -28,6 +28,7 @@ from .tokenization_xlm import XLMTokenizer
|
||||
from .tokenization_roberta import RobertaTokenizer
|
||||
from .tokenization_distilbert import DistilBertTokenizer
|
||||
from .tokenization_camembert import CamembertTokenizer
|
||||
from .tokenization_albert import AlbertTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,16 +43,17 @@ class AutoTokenizer(object):
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
||||
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
||||
- contains `albert`: AlbertTokenizer (ALBERT model)
|
||||
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
||||
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
||||
- contains `bert`: BertTokenizer (Bert model)
|
||||
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
|
||||
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
|
||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
|
||||
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
||||
- contains `xlm`: XLMTokenizer (XLM model)
|
||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||
|
||||
This class cannot be instantiated using `__init__()` (throw an error).
|
||||
"""
|
||||
@@ -66,16 +68,17 @@ class AutoTokenizer(object):
|
||||
|
||||
The tokenizer class to instantiate is selected as the first pattern matching
|
||||
in the `pretrained_model_name_or_path` string (in the following order):
|
||||
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
||||
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
||||
- contains `albert`: AlbertTokenizer (ALBERT model)
|
||||
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
||||
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
||||
- contains `bert`: BertTokenizer (Bert model)
|
||||
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
|
||||
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
|
||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
|
||||
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
||||
- contains `xlm`: XLMTokenizer (XLM model)
|
||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
@@ -109,6 +112,8 @@ class AutoTokenizer(object):
|
||||
"""
|
||||
if 'distilbert' in pretrained_model_name_or_path:
|
||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif 'albert' in pretrained_model_name_or_path:
|
||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif 'camembert' in pretrained_model_name_or_path:
|
||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
elif 'roberta' in pretrained_model_name_or_path:
|
||||
@@ -129,4 +134,4 @@ class AutoTokenizer(object):
|
||||
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||
"'xlm', 'roberta', 'camembert', 'ctrl'".format(pretrained_model_name_or_path))
|
||||
"'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
|
||||
|
||||
Reference in New Issue
Block a user