Add ALBERT to AutoClasses
This commit is contained in:
committed by
Lysandre Debut
parent
4a666885b5
commit
ecf15ebf3b
@@ -28,6 +28,7 @@ from .configuration_roberta import RobertaConfig
|
|||||||
from .configuration_distilbert import DistilBertConfig
|
from .configuration_distilbert import DistilBertConfig
|
||||||
from .configuration_ctrl import CTRLConfig
|
from .configuration_ctrl import CTRLConfig
|
||||||
from .configuration_camembert import CamembertConfig
|
from .configuration_camembert import CamembertConfig
|
||||||
|
from .configuration_albert import AlbertConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -44,14 +45,15 @@ class AutoConfig(object):
|
|||||||
The base model class to instantiate is selected as the first pattern matching
|
The base model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertConfig (DistilBERT model)
|
- contains `distilbert`: DistilBertConfig (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertConfig (ALBERT model)
|
||||||
|
- contains `camembert`: CamembertConfig (CamemBERT model)
|
||||||
|
- contains `roberta`: RobertaConfig (RoBERTa model)
|
||||||
- contains `bert`: BertConfig (Bert model)
|
- contains `bert`: BertConfig (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
|
||||||
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetConfig (XLNet model)
|
- contains `xlnet`: XLNetConfig (XLNet model)
|
||||||
- contains `xlm`: XLMConfig (XLM model)
|
- contains `xlm`: XLMConfig (XLM model)
|
||||||
- contains `roberta`: RobertaConfig (RoBERTa model)
|
|
||||||
- contains `camembert`: CamembertConfig (CamemBERT model)
|
|
||||||
- contains `ctrl` : CTRLConfig (CTRL model)
|
- contains `ctrl` : CTRLConfig (CTRL model)
|
||||||
This class cannot be instantiated using `__init__()` (throw an error).
|
This class cannot be instantiated using `__init__()` (throw an error).
|
||||||
"""
|
"""
|
||||||
@@ -67,14 +69,15 @@ class AutoConfig(object):
|
|||||||
The configuration class to instantiate is selected as the first pattern matching
|
The configuration class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertConfig (DistilBERT model)
|
- contains `distilbert`: DistilBertConfig (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertConfig (ALBERT model)
|
||||||
|
- contains `camembert`: CamembertConfig (CamemBERT model)
|
||||||
|
- contains `roberta`: RobertaConfig (RoBERTa model)
|
||||||
- contains `bert`: BertConfig (Bert model)
|
- contains `bert`: BertConfig (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
|
||||||
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetConfig (XLNet model)
|
- contains `xlnet`: XLNetConfig (XLNet model)
|
||||||
- contains `xlm`: XLMConfig (XLM model)
|
- contains `xlm`: XLMConfig (XLM model)
|
||||||
- contains `roberta`: RobertaConfig (RoBERTa model)
|
|
||||||
- contains `camembert`: CamembertConfig (CamemBERT model)
|
|
||||||
- contains `ctrl` : CTRLConfig (CTRL model)
|
- contains `ctrl` : CTRLConfig (CTRL model)
|
||||||
Params:
|
Params:
|
||||||
pretrained_model_name_or_path: either:
|
pretrained_model_name_or_path: either:
|
||||||
@@ -122,6 +125,8 @@ class AutoConfig(object):
|
|||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
elif 'albert' in pretrained_model_name_or_path:
|
||||||
|
return AlbertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif 'camembert' in pretrained_model_name_or_path:
|
||||||
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
@@ -142,4 +147,4 @@ class AutoConfig(object):
|
|||||||
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm', 'roberta', 'camembert', 'ctrl'".format(pretrained_model_name_or_path))
|
"'xlm', 'roberta', 'distilbert', 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
|
||||||
|
|||||||
@@ -28,6 +28,8 @@ from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassifica
|
|||||||
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
|
from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||||
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
|
from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification
|
||||||
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice
|
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice
|
||||||
|
from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertForSequenceClassification, CamembertForMultipleChoice
|
||||||
|
from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, AlbertForQuestionAnswering
|
||||||
|
|
||||||
from .modeling_utils import PreTrainedModel, SequenceSummary
|
from .modeling_utils import PreTrainedModel, SequenceSummary
|
||||||
|
|
||||||
@@ -49,15 +51,16 @@ class AutoModel(object):
|
|||||||
The base model class to instantiate is selected as the first pattern matching
|
The base model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertModel (ALBERT model)
|
||||||
- contains `camembert`: CamembertModel (CamemBERT model)
|
- contains `camembert`: CamembertModel (CamemBERT model)
|
||||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||||
- contains `bert`: BertModel (Bert model)
|
- contains `bert`: BertModel (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||||
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
|
|
||||||
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetModel (XLNet model)
|
- contains `xlnet`: XLNetModel (XLNet model)
|
||||||
- contains `xlm`: XLMModel (XLM model)
|
- contains `xlm`: XLMModel (XLM model)
|
||||||
|
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
|
||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
@@ -73,15 +76,16 @@ class AutoModel(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
- contains `distilbert`: DistilBertModel (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertModel (ALBERT model)
|
||||||
- contains `camembert`: CamembertModel (CamemBERT model)
|
- contains `camembert`: CamembertModel (CamemBERT model)
|
||||||
- contains `roberta`: RobertaModel (RoBERTa model)
|
- contains `roberta`: RobertaModel (RoBERTa model)
|
||||||
- contains `bert`: BertModel (Bert model)
|
- contains `bert`: BertModel (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTModel (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Model (OpenAI GPT-2 model)
|
||||||
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
|
|
||||||
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLModel (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetModel (XLNet model)
|
- contains `xlnet`: XLNetModel (XLNet model)
|
||||||
- contains `xlm`: XLMModel (XLM model)
|
- contains `xlm`: XLMModel (XLM model)
|
||||||
|
- contains `ctrl`: CTRLModel (Salesforce CTRL model)
|
||||||
|
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
To train the model, you should first set it back in training mode with `model.train()`
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
@@ -144,6 +148,8 @@ class AutoModel(object):
|
|||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'albert' in pretrained_model_name_or_path:
|
||||||
|
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif 'camembert' in pretrained_model_name_or_path:
|
||||||
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
@@ -164,7 +170,7 @@ class AutoModel(object):
|
|||||||
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm', 'roberta, 'ctrl'".format(pretrained_model_name_or_path))
|
"'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
||||||
|
|
||||||
|
|
||||||
class AutoModelWithLMHead(object):
|
class AutoModelWithLMHead(object):
|
||||||
@@ -180,15 +186,16 @@ class AutoModelWithLMHead(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
|
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertForMaskedLM (ALBERT model)
|
||||||
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
|
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
|
||||||
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
|
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
|
||||||
- contains `bert`: BertForMaskedLM (Bert model)
|
- contains `bert`: BertForMaskedLM (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTLMHeadModel (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2LMHeadModel (OpenAI GPT-2 model)
|
||||||
- contains `ctrl`: CTRLLMModel (Salesforce CTRL model)
|
|
||||||
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
|
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
|
||||||
- contains `xlm`: XLMWithLMHeadModel (XLM model)
|
- contains `xlm`: XLMWithLMHeadModel (XLM model)
|
||||||
|
- contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
|
||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throws an error).
|
This class cannot be instantiated using `__init__()` (throws an error).
|
||||||
"""
|
"""
|
||||||
@@ -207,6 +214,7 @@ class AutoModelWithLMHead(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
|
- contains `distilbert`: DistilBertForMaskedLM (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertForMaskedLM (ALBERT model)
|
||||||
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
|
- contains `camembert`: CamembertForMaskedLM (CamemBERT model)
|
||||||
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
|
- contains `roberta`: RobertaForMaskedLM (RoBERTa model)
|
||||||
- contains `bert`: BertForMaskedLM (Bert model)
|
- contains `bert`: BertForMaskedLM (Bert model)
|
||||||
@@ -215,6 +223,7 @@ class AutoModelWithLMHead(object):
|
|||||||
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLLMHeadModel (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
|
- contains `xlnet`: XLNetLMHeadModel (XLNet model)
|
||||||
- contains `xlm`: XLMWithLMHeadModel (XLM model)
|
- contains `xlm`: XLMWithLMHeadModel (XLM model)
|
||||||
|
- contains `ctrl`: CTRLLMHeadModel (Salesforce CTRL model)
|
||||||
|
|
||||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
|
||||||
To train the model, you should first set it back in training mode with `model.train()`
|
To train the model, you should first set it back in training mode with `model.train()`
|
||||||
@@ -276,6 +285,8 @@ class AutoModelWithLMHead(object):
|
|||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'albert' in pretrained_model_name_or_path:
|
||||||
|
return AlbertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif 'camembert' in pretrained_model_name_or_path:
|
||||||
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
@@ -296,7 +307,7 @@ class AutoModelWithLMHead(object):
|
|||||||
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm', 'roberta','ctrl'".format(pretrained_model_name_or_path))
|
"'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForSequenceClassification(object):
|
class AutoModelForSequenceClassification(object):
|
||||||
@@ -312,6 +323,7 @@ class AutoModelForSequenceClassification(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
|
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertForSequenceClassification (ALBERT model)
|
||||||
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
|
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
|
||||||
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
|
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
|
||||||
- contains `bert`: BertForSequenceClassification (Bert model)
|
- contains `bert`: BertForSequenceClassification (Bert model)
|
||||||
@@ -335,6 +347,7 @@ class AutoModelForSequenceClassification(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
|
- contains `distilbert`: DistilBertForSequenceClassification (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertForSequenceClassification (ALBERT model)
|
||||||
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
|
- contains `camembert`: CamembertForSequenceClassification (CamemBERT model)
|
||||||
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
|
- contains `roberta`: RobertaForSequenceClassification (RoBERTa model)
|
||||||
- contains `bert`: BertForSequenceClassification (Bert model)
|
- contains `bert`: BertForSequenceClassification (Bert model)
|
||||||
@@ -402,6 +415,8 @@ class AutoModelForSequenceClassification(object):
|
|||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'albert' in pretrained_model_name_or_path:
|
||||||
|
return AlbertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif 'camembert' in pretrained_model_name_or_path:
|
||||||
return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return CamembertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
@@ -414,7 +429,7 @@ class AutoModelForSequenceClassification(object):
|
|||||||
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMForSequenceClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path))
|
"'bert', 'xlnet', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(pretrained_model_name_or_path))
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForQuestionAnswering(object):
|
class AutoModelForQuestionAnswering(object):
|
||||||
@@ -430,6 +445,7 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
|
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertForQuestionAnswering (ALBERT model)
|
||||||
- contains `bert`: BertForQuestionAnswering (Bert model)
|
- contains `bert`: BertForQuestionAnswering (Bert model)
|
||||||
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
|
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
|
||||||
- contains `xlm`: XLMForQuestionAnswering (XLM model)
|
- contains `xlm`: XLMForQuestionAnswering (XLM model)
|
||||||
@@ -451,6 +467,7 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
The model class to instantiate is selected as the first pattern matching
|
The model class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
|
- contains `distilbert`: DistilBertForQuestionAnswering (DistilBERT model)
|
||||||
|
- contains `albert`: AlbertForQuestionAnswering (ALBERT model)
|
||||||
- contains `bert`: BertForQuestionAnswering (Bert model)
|
- contains `bert`: BertForQuestionAnswering (Bert model)
|
||||||
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
|
- contains `xlnet`: XLNetForQuestionAnswering (XLNet model)
|
||||||
- contains `xlm`: XLMForQuestionAnswering (XLM model)
|
- contains `xlm`: XLMForQuestionAnswering (XLM model)
|
||||||
@@ -513,6 +530,8 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return DistilBertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
elif 'albert' in pretrained_model_name_or_path:
|
||||||
|
return AlbertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'bert' in pretrained_model_name_or_path:
|
elif 'bert' in pretrained_model_name_or_path:
|
||||||
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return BertForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
elif 'xlnet' in pretrained_model_name_or_path:
|
elif 'xlnet' in pretrained_model_name_or_path:
|
||||||
@@ -521,4 +540,4 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
return XLMForQuestionAnswering.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||||
|
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path))
|
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path))
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from .tokenization_xlm import XLMTokenizer
|
|||||||
from .tokenization_roberta import RobertaTokenizer
|
from .tokenization_roberta import RobertaTokenizer
|
||||||
from .tokenization_distilbert import DistilBertTokenizer
|
from .tokenization_distilbert import DistilBertTokenizer
|
||||||
from .tokenization_camembert import CamembertTokenizer
|
from .tokenization_camembert import CamembertTokenizer
|
||||||
|
from .tokenization_albert import AlbertTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,16 +43,17 @@ class AutoTokenizer(object):
|
|||||||
|
|
||||||
The tokenizer class to instantiate is selected as the first pattern matching
|
The tokenizer class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
|
||||||
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
||||||
|
- contains `albert`: AlbertTokenizer (ALBERT model)
|
||||||
|
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
||||||
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
||||||
- contains `bert`: BertTokenizer (Bert model)
|
- contains `bert`: BertTokenizer (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
|
||||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
|
||||||
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
||||||
- contains `xlm`: XLMTokenizer (XLM model)
|
- contains `xlm`: XLMTokenizer (XLM model)
|
||||||
|
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||||
|
|
||||||
This class cannot be instantiated using `__init__()` (throw an error).
|
This class cannot be instantiated using `__init__()` (throw an error).
|
||||||
"""
|
"""
|
||||||
@@ -66,16 +68,17 @@ class AutoTokenizer(object):
|
|||||||
|
|
||||||
The tokenizer class to instantiate is selected as the first pattern matching
|
The tokenizer class to instantiate is selected as the first pattern matching
|
||||||
in the `pretrained_model_name_or_path` string (in the following order):
|
in the `pretrained_model_name_or_path` string (in the following order):
|
||||||
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
|
||||||
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
- contains `distilbert`: DistilBertTokenizer (DistilBert model)
|
||||||
|
- contains `albert`: AlbertTokenizer (ALBERT model)
|
||||||
|
- contains `camembert`: CamembertTokenizer (CamemBERT model)
|
||||||
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
- contains `roberta`: RobertaTokenizer (RoBERTa model)
|
||||||
- contains `bert`: BertTokenizer (Bert model)
|
- contains `bert`: BertTokenizer (Bert model)
|
||||||
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
|
- contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
|
||||||
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
|
- contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
|
||||||
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
|
||||||
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
|
- contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
|
||||||
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
- contains `xlnet`: XLNetTokenizer (XLNet model)
|
||||||
- contains `xlm`: XLMTokenizer (XLM model)
|
- contains `xlm`: XLMTokenizer (XLM model)
|
||||||
|
- contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
|
||||||
|
|
||||||
Params:
|
Params:
|
||||||
pretrained_model_name_or_path: either:
|
pretrained_model_name_or_path: either:
|
||||||
@@ -109,6 +112,8 @@ class AutoTokenizer(object):
|
|||||||
"""
|
"""
|
||||||
if 'distilbert' in pretrained_model_name_or_path:
|
if 'distilbert' in pretrained_model_name_or_path:
|
||||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
elif 'albert' in pretrained_model_name_or_path:
|
||||||
|
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
elif 'camembert' in pretrained_model_name_or_path:
|
elif 'camembert' in pretrained_model_name_or_path:
|
||||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
elif 'roberta' in pretrained_model_name_or_path:
|
elif 'roberta' in pretrained_model_name_or_path:
|
||||||
@@ -129,4 +134,4 @@ class AutoTokenizer(object):
|
|||||||
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
"'xlm', 'roberta', 'camembert', 'ctrl'".format(pretrained_model_name_or_path))
|
"'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(pretrained_model_name_or_path))
|
||||||
|
|||||||
Reference in New Issue
Block a user