Map configs to models and tokenizers
This commit is contained in:
@@ -202,7 +202,7 @@ class AutoConfig:
|
|||||||
return config_class.from_dict(config_dict, **kwargs)
|
return config_class.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should have a `model_type` key in its config.json, or contain one of {}".format(
|
"Unrecognized model in {}. "
|
||||||
pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys())
|
"Should have a `model_type` key in its config.json, or contain one of the following strings "
|
||||||
)
|
"in its name: {}".format(pretrained_model_name_or_path, ", ".join(CONFIG_MAPPING.keys()))
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ class PretrainedConfig(object):
|
|||||||
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
|
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
|
||||||
``torchscript``: string, default `False`. Is the model used with Torchscript.
|
``torchscript``: string, default `False`. Is the model used with Torchscript.
|
||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = {} # type: Dict[str, str]
|
pretrained_config_archive_map = {} # type: Dict[str, str]
|
||||||
model_type = "" # type: str
|
model_type = "" # type: str
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Attributes with defaults
|
# Attributes with defaults
|
||||||
@@ -273,7 +273,7 @@ class PretrainedConfig(object):
|
|||||||
return self.__dict__ == other.__dict__
|
return self.__dict__ == other.__dict__
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return str(self.to_json_string())
|
return "{} {}".format(self.__class__.__name__, self.to_json_string())
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""Serializes this instance to a Python dictionary."""
|
"""Serializes this instance to a Python dictionary."""
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Type
|
from typing import Dict, Type
|
||||||
|
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
@@ -126,14 +126,14 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
|||||||
for key, value, in pretrained_map.items()
|
for key, value, in pretrained_map.items()
|
||||||
)
|
)
|
||||||
|
|
||||||
MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
[
|
[
|
||||||
(T5Config, T5Model),
|
(T5Config, T5Model),
|
||||||
(DistilBertConfig, DistilBertModel),
|
(DistilBertConfig, DistilBertModel),
|
||||||
(AlbertConfig, AlbertModel),
|
(AlbertConfig, AlbertModel),
|
||||||
(CamembertConfig, CamembertModel),
|
(CamembertConfig, CamembertModel),
|
||||||
(RobertaConfig, XLMRobertaModel),
|
(RobertaConfig, RobertaModel),
|
||||||
(XLMRobertaConfig, RobertaModel),
|
(XLMRobertaConfig, XLMRobertaModel),
|
||||||
(BertConfig, BertModel),
|
(BertConfig, BertModel),
|
||||||
(OpenAIGPTConfig, OpenAIGPTModel),
|
(OpenAIGPTConfig, OpenAIGPTModel),
|
||||||
(GPT2Config, GPT2Model),
|
(GPT2Config, GPT2Model),
|
||||||
@@ -144,12 +144,53 @@ MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = Orde
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(T5Config, T5WithLMHeadModel),
|
||||||
|
(DistilBertConfig, DistilBertForMaskedLM),
|
||||||
|
(AlbertConfig, AlbertForMaskedLM),
|
||||||
|
(CamembertConfig, CamembertForMaskedLM),
|
||||||
|
(RobertaConfig, RobertaForMaskedLM),
|
||||||
|
(XLMRobertaConfig, XLMRobertaForMaskedLM),
|
||||||
|
(BertConfig, BertForMaskedLM),
|
||||||
|
(OpenAIGPTConfig, OpenAIGPTLMHeadModel),
|
||||||
|
(GPT2Config, GPT2LMHeadModel),
|
||||||
|
(TransfoXLConfig, TransfoXLLMHeadModel),
|
||||||
|
(XLNetConfig, XLNetLMHeadModel),
|
||||||
|
(XLMConfig, XLMWithLMHeadModel),
|
||||||
|
(CTRLConfig, CTRLLMHeadModel),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, DistilBertForSequenceClassification),
|
||||||
|
(AlbertConfig, AlbertForSequenceClassification),
|
||||||
|
(CamembertConfig, CamembertForSequenceClassification),
|
||||||
|
(RobertaConfig, RobertaForSequenceClassification),
|
||||||
|
(XLMRobertaConfig, XLMRobertaForSequenceClassification),
|
||||||
|
(BertConfig, BertForSequenceClassification),
|
||||||
|
(XLNetConfig, XLNetForSequenceClassification),
|
||||||
|
(XLMConfig, XLMForSequenceClassification),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, DistilBertForQuestionAnswering),
|
||||||
|
(AlbertConfig, AlbertForQuestionAnswering),
|
||||||
|
(BertConfig, BertForQuestionAnswering),
|
||||||
|
(XLNetConfig, XLNetForQuestionAnswering),
|
||||||
|
(XLMConfig, XLMForQuestionAnswering),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
[
|
[
|
||||||
(DistilBertConfig, DistilBertForTokenClassification),
|
(DistilBertConfig, DistilBertForTokenClassification),
|
||||||
(CamembertConfig, CamembertForTokenClassification),
|
(CamembertConfig, CamembertForTokenClassification),
|
||||||
(RobertaConfig, XLMRobertaForTokenClassification),
|
(RobertaConfig, RobertaForTokenClassification),
|
||||||
(XLMRobertaConfig, RobertaForTokenClassification),
|
(XLMRobertaConfig, XLMRobertaForTokenClassification),
|
||||||
(BertConfig, BertForTokenClassification),
|
(BertConfig, BertForTokenClassification),
|
||||||
(XLNetConfig, XLNetForTokenClassification),
|
(XLNetConfig, XLNetForTokenClassification),
|
||||||
]
|
]
|
||||||
@@ -218,7 +259,12 @@ class AutoModel(object):
|
|||||||
for config_class, model_class in MODEL_MAPPING.items():
|
for config_class, model_class in MODEL_MAPPING.items():
|
||||||
if isinstance(config, config_class):
|
if isinstance(config, config_class):
|
||||||
return model_class(config)
|
return model_class(config)
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
raise ValueError(
|
||||||
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -309,10 +355,9 @@ class AutoModel(object):
|
|||||||
if isinstance(config, config_class):
|
if isinstance(config, config_class):
|
||||||
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"Model type should be one of {}.".format(
|
||||||
"'xlm-roberta', 'xlm', 'roberta, 'ctrl', 'distilbert', 'camembert', 'albert'".format(
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
|
||||||
pretrained_model_name_or_path
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -376,27 +421,15 @@ class AutoModelWithLMHead(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||||
return DistilBertForMaskedLM(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, RobertaConfig):
|
return model_class(config)
|
||||||
return RobertaForMaskedLM(config)
|
raise ValueError(
|
||||||
elif isinstance(config, BertConfig):
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
return BertForMaskedLM(config)
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||||
return OpenAIGPTLMHeadModel(config)
|
)
|
||||||
elif isinstance(config, GPT2Config):
|
)
|
||||||
return GPT2LMHeadModel(config)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TransfoXLLMHeadModel(config)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetLMHeadModel(config)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMWithLMHeadModel(config)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return CTRLLMHeadModel(config)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaForMaskedLM(config)
|
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -486,57 +519,13 @@ class AutoModelWithLMHead(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, T5Config):
|
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||||
return T5WithLMHeadModel.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, DistilBertConfig):
|
|
||||||
return DistilBertForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return AlbertForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, CamembertConfig):
|
|
||||||
return CamembertForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return RobertaForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertForMaskedLM.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
|
||||||
return OpenAIGPTLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, GPT2Config):
|
|
||||||
return GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TransfoXLLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMWithLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return CTRLLMHeadModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"Model type should be one of {}.".format(
|
||||||
"'xlm-roberta', 'xlm', 'roberta','ctrl', 'distilbert', 'camembert', 'albert'".format(
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||||
pretrained_model_name_or_path
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -591,23 +580,17 @@ class AutoModelForSequenceClassification(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, AlbertConfig):
|
for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||||
return AlbertForSequenceClassification(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, CamembertConfig):
|
return model_class(config)
|
||||||
return CamembertForSequenceClassification(config)
|
raise ValueError(
|
||||||
elif isinstance(config, DistilBertConfig):
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
return DistilBertForSequenceClassification(config)
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, RobertaConfig):
|
config.__class__,
|
||||||
return RobertaForSequenceClassification(config)
|
cls.__name__,
|
||||||
elif isinstance(config, BertConfig):
|
", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
|
||||||
return BertForSequenceClassification(config)
|
)
|
||||||
elif isinstance(config, XLNetConfig):
|
)
|
||||||
return XLNetForSequenceClassification(config)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMForSequenceClassification(config)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaForSequenceClassification(config)
|
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -693,43 +676,15 @@ class AutoModelForSequenceClassification(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||||
return DistilBertForSequenceClassification.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return AlbertForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, CamembertConfig):
|
|
||||||
return CamembertForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return RobertaForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
"'bert', 'xlnet', 'xlm-roberta', 'xlm', 'roberta', 'distilbert', 'camembert', 'albert'".format(
|
"Model type should be one of {}.".format(
|
||||||
pretrained_model_name_or_path
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -780,17 +735,18 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, AlbertConfig):
|
for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||||
return AlbertForQuestionAnswering(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, DistilBertConfig):
|
return model_class(config)
|
||||||
return DistilBertForQuestionAnswering(config)
|
|
||||||
elif isinstance(config, BertConfig):
|
raise ValueError(
|
||||||
return BertForQuestionAnswering(config)
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
elif isinstance(config, XLNetConfig):
|
"Model type should be one of {}.".format(
|
||||||
return XLNetForQuestionAnswering(config)
|
config.__class__,
|
||||||
elif isinstance(config, XLMConfig):
|
cls.__name__,
|
||||||
return XLMForQuestionAnswering(config)
|
", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -870,30 +826,17 @@ class AutoModelForQuestionAnswering(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||||
return DistilBertForQuestionAnswering.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return AlbertForQuestionAnswering.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertForQuestionAnswering.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetForQuestionAnswering.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMForQuestionAnswering.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
"'bert', 'xlnet', 'xlm', 'distilbert', 'albert'".format(pretrained_model_name_or_path)
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -923,19 +866,18 @@ class AutoModelForTokenClassification:
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, CamembertConfig):
|
for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||||
return CamembertForTokenClassification(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, DistilBertConfig):
|
return model_class(config)
|
||||||
return DistilBertForTokenClassification(config)
|
|
||||||
elif isinstance(config, BertConfig):
|
raise ValueError(
|
||||||
return BertForTokenClassification(config)
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
elif isinstance(config, XLNetConfig):
|
"Model type should be one of {}.".format(
|
||||||
return XLNetForTokenClassification(config)
|
config.__class__,
|
||||||
elif isinstance(config, RobertaConfig):
|
cls.__name__,
|
||||||
return RobertaForTokenClassification(config)
|
", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
)
|
||||||
return XLMRobertaForTokenClassification(config)
|
)
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -1014,34 +956,15 @@ class AutoModelForTokenClassification:
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, CamembertConfig):
|
for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||||
return CamembertForTokenClassification.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, DistilBertConfig):
|
|
||||||
return DistilBertForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return RobertaForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
|
||||||
"'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(
|
"Model type should be one of {}.".format(
|
||||||
pretrained_model_name_or_path
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Dict, Type
|
||||||
|
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
@@ -70,6 +72,7 @@ from .modeling_tf_transfo_xl import (
|
|||||||
TFTransfoXLLMHeadModel,
|
TFTransfoXLLMHeadModel,
|
||||||
TFTransfoXLModel,
|
TFTransfoXLModel,
|
||||||
)
|
)
|
||||||
|
from .modeling_tf_utils import TFPreTrainedModel
|
||||||
from .modeling_tf_xlm import (
|
from .modeling_tf_xlm import (
|
||||||
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
TFXLMForQuestionAnsweringSimple,
|
TFXLMForQuestionAnsweringSimple,
|
||||||
@@ -108,6 +111,65 @@ TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
|||||||
for key, value, in pretrained_map.items()
|
for key, value, in pretrained_map.items()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
TF_MODEL_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, TFDistilBertModel),
|
||||||
|
(AlbertConfig, TFAlbertModel),
|
||||||
|
(RobertaConfig, TFRobertaModel),
|
||||||
|
(BertConfig, TFBertModel),
|
||||||
|
(OpenAIGPTConfig, TFOpenAIGPTModel),
|
||||||
|
(GPT2Config, TFGPT2Model),
|
||||||
|
(TransfoXLConfig, TFTransfoXLModel),
|
||||||
|
(XLNetConfig, TFXLNetModel),
|
||||||
|
(XLMConfig, TFXLMModel),
|
||||||
|
(CTRLConfig, TFCTRLModel),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
TF_MODEL_WITH_LM_HEAD_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, TFDistilBertForMaskedLM),
|
||||||
|
(AlbertConfig, TFAlbertForMaskedLM),
|
||||||
|
(RobertaConfig, TFRobertaForMaskedLM),
|
||||||
|
(BertConfig, TFBertForMaskedLM),
|
||||||
|
(OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
|
||||||
|
(GPT2Config, TFGPT2LMHeadModel),
|
||||||
|
(TransfoXLConfig, TFTransfoXLLMHeadModel),
|
||||||
|
(XLNetConfig, TFXLNetLMHeadModel),
|
||||||
|
(XLMConfig, TFXLMWithLMHeadModel),
|
||||||
|
(CTRLConfig, TFCTRLLMHeadModel),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, TFDistilBertForSequenceClassification),
|
||||||
|
(AlbertConfig, TFAlbertForSequenceClassification),
|
||||||
|
(RobertaConfig, TFRobertaForSequenceClassification),
|
||||||
|
(BertConfig, TFBertForSequenceClassification),
|
||||||
|
(XLNetConfig, TFXLNetForSequenceClassification),
|
||||||
|
(XLMConfig, TFXLMForSequenceClassification),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, TFDistilBertForQuestionAnswering),
|
||||||
|
(BertConfig, TFBertForQuestionAnswering),
|
||||||
|
(XLNetConfig, TFXLNetForQuestionAnsweringSimple),
|
||||||
|
(XLMConfig, TFXLMForQuestionAnsweringSimple),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: Dict[Type[PretrainedConfig], Type[TFPreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, TFDistilBertForTokenClassification),
|
||||||
|
(RobertaConfig, TFRobertaForTokenClassification),
|
||||||
|
(BertConfig, TFBertForTokenClassification),
|
||||||
|
(XLNetConfig, TFXLNetForTokenClassification),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TFAutoModel(object):
|
class TFAutoModel(object):
|
||||||
r"""
|
r"""
|
||||||
@@ -165,25 +227,15 @@ class TFAutoModel(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = TFAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = TFAutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in TF_MODEL_MAPPING.items():
|
||||||
return TFDistilBertModel(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, RobertaConfig):
|
return model_class(config)
|
||||||
return TFRobertaModel(config)
|
raise ValueError(
|
||||||
elif isinstance(config, BertConfig):
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
return TFBertModel(config)
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys())
|
||||||
return TFOpenAIGPTModel(config)
|
)
|
||||||
elif isinstance(config, GPT2Config):
|
)
|
||||||
return TFGPT2Model(config)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TFTransfoXLModel(config)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetModel(config)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return TFXLMModel(config)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return TFCTRLModel(config)
|
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -266,39 +318,14 @@ class TFAutoModel(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, T5Config):
|
for config_class, model_class in TF_MODEL_MAPPING.items():
|
||||||
return TFT5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, DistilBertConfig):
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
return TFDistilBertModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return TFAlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return TFRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return TFBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
|
||||||
return TFOpenAIGPTModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, GPT2Config):
|
|
||||||
return TFGPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TFTransfoXLModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return TFXLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return TFCTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"Model type should be one of {}.".format(
|
||||||
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_MAPPING.keys())
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -358,25 +385,15 @@ class TFAutoModelWithLMHead(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelWithLMHead.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||||
return TFDistilBertForMaskedLM(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, RobertaConfig):
|
return model_class(config)
|
||||||
return TFRobertaForMaskedLM(config)
|
raise ValueError(
|
||||||
elif isinstance(config, BertConfig):
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
return TFBertForMaskedLM(config)
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||||
return TFOpenAIGPTLMHeadModel(config)
|
)
|
||||||
elif isinstance(config, GPT2Config):
|
)
|
||||||
return TFGPT2LMHeadModel(config)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TFTransfoXLLMHeadModel(config)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetLMHeadModel(config)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return TFXLMWithLMHeadModel(config)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return TFCTRLLMHeadModel(config)
|
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -464,55 +481,14 @@ class TFAutoModelWithLMHead(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, T5Config):
|
for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
|
||||||
return TFT5WithLMHeadModel.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, DistilBertConfig):
|
|
||||||
return TFDistilBertForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return TFAlbertForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return TFRobertaForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return TFBertForMaskedLM.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
|
||||||
return TFOpenAIGPTLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, GPT2Config):
|
|
||||||
return TFGPT2LMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TFTransfoXLLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return TFXLMWithLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return TFCTRLLMHeadModel.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
"'distilbert', 'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"Model type should be one of {}.".format(
|
||||||
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)
|
config.__class__, cls.__name__, ", ".join(c.__name__ for c in TF_MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -563,17 +539,17 @@ class TFAutoModelForSequenceClassification(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||||
return TFDistilBertForSequenceClassification(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, RobertaConfig):
|
return model_class(config)
|
||||||
return TFRobertaForSequenceClassification(config)
|
raise ValueError(
|
||||||
elif isinstance(config, BertConfig):
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
return TFBertForSequenceClassification(config)
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, XLNetConfig):
|
config.__class__,
|
||||||
return TFXLNetForSequenceClassification(config)
|
cls.__name__,
|
||||||
elif isinstance(config, XLMConfig):
|
", ".join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
|
||||||
return TFXLMForSequenceClassification(config)
|
)
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -659,34 +635,16 @@ class TFAutoModelForSequenceClassification(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
|
||||||
return TFDistilBertForSequenceClassification.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return TFAlbertForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return TFRobertaForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return TFBertForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return TFXLMForSequenceClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
"'distilbert', 'bert', 'xlnet', 'xlm', 'roberta'".format(pretrained_model_name_or_path)
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys()),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -735,15 +693,17 @@ class TFAutoModelForQuestionAnswering(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModelForSequenceClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||||
return TFDistilBertForQuestionAnswering(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, BertConfig):
|
return model_class(config)
|
||||||
return TFBertForQuestionAnswering(config)
|
raise ValueError(
|
||||||
elif isinstance(config, XLNetConfig):
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
raise NotImplementedError("TFXLNetForQuestionAnswering isn't implemented")
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, XLMConfig):
|
config.__class__,
|
||||||
raise NotImplementedError("TFXLMForQuestionAnswering isn't implemented")
|
cls.__name__,
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
", ".join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -828,26 +788,16 @@ class TFAutoModelForQuestionAnswering(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
|
||||||
return TFDistilBertForQuestionAnswering.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return TFBertForQuestionAnswering.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetForQuestionAnsweringSimple.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return TFXLMForQuestionAnsweringSimple.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
"'distilbert', 'bert', 'xlnet', 'xlm'".format(pretrained_model_name_or_path)
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys()),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -876,15 +826,17 @@ class TFAutoModelForTokenClassification:
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = TFAutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = TFAutoModelForTokenClassification.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, BertConfig):
|
for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||||
return TFBertForTokenClassification(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, XLNetConfig):
|
return model_class(config)
|
||||||
return TFXLNetForTokenClassification(config)
|
raise ValueError(
|
||||||
elif isinstance(config, DistilBertConfig):
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
return TFDistilBertForTokenClassification(config)
|
"Model type should be one of {}.".format(
|
||||||
elif isinstance(config, RobertaConfig):
|
config.__class__,
|
||||||
return TFRobertaForTokenClassification(config)
|
cls.__name__,
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
", ".join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||||
@@ -962,24 +914,14 @@ class TFAutoModelForTokenClassification:
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, BertConfig):
|
for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
|
||||||
return TFBertForTokenClassification.from_pretrained(
|
if isinstance(config, config_class):
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return TFXLNetForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, DistilBertConfig):
|
|
||||||
return TFDistilBertForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return TFRobertaForTokenClassification.from_pretrained(
|
|
||||||
pretrained_model_name_or_path, *model_args, config=config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
|
||||||
"'bert', 'xlnet', 'distilbert', 'roberta'".format(pretrained_model_name_or_path)
|
"Model type should be one of {}.".format(
|
||||||
|
config.__class__,
|
||||||
|
cls.__name__,
|
||||||
|
", ".join(c.__name__ for c in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys()),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Dict, Type
|
||||||
|
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
@@ -45,6 +47,7 @@ from .tokenization_openai import OpenAIGPTTokenizer
|
|||||||
from .tokenization_roberta import RobertaTokenizer
|
from .tokenization_roberta import RobertaTokenizer
|
||||||
from .tokenization_t5 import T5Tokenizer
|
from .tokenization_t5 import T5Tokenizer
|
||||||
from .tokenization_transfo_xl import TransfoXLTokenizer
|
from .tokenization_transfo_xl import TransfoXLTokenizer
|
||||||
|
from .tokenization_utils import PreTrainedTokenizer
|
||||||
from .tokenization_xlm import XLMTokenizer
|
from .tokenization_xlm import XLMTokenizer
|
||||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
from .tokenization_xlnet import XLNetTokenizer
|
from .tokenization_xlnet import XLNetTokenizer
|
||||||
@@ -53,6 +56,25 @@ from .tokenization_xlnet import XLNetTokenizer
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
TOKENIZER_MAPPING: Dict[Type[PretrainedConfig], Type[PreTrainedTokenizer]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(T5Config, T5Tokenizer),
|
||||||
|
(DistilBertConfig, DistilBertTokenizer),
|
||||||
|
(AlbertConfig, AlbertTokenizer),
|
||||||
|
(CamembertConfig, CamembertTokenizer),
|
||||||
|
(RobertaConfig, XLMRobertaTokenizer),
|
||||||
|
(XLMRobertaConfig, RobertaTokenizer),
|
||||||
|
(BertConfig, BertTokenizer),
|
||||||
|
(OpenAIGPTConfig, OpenAIGPTTokenizer),
|
||||||
|
(GPT2Config, GPT2Tokenizer),
|
||||||
|
(TransfoXLConfig, TransfoXLTokenizer),
|
||||||
|
(XLNetConfig, XLNetTokenizer),
|
||||||
|
(XLMConfig, XLMTokenizer),
|
||||||
|
(CTRLConfig, CTRLTokenizer),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoTokenizer(object):
|
class AutoTokenizer(object):
|
||||||
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
|
r""":class:`~transformers.AutoTokenizer` is a generic tokenizer class
|
||||||
that will be instantiated as one of the tokenizer classes of the library
|
that will be instantiated as one of the tokenizer classes of the library
|
||||||
@@ -154,36 +176,13 @@ class AutoTokenizer(object):
|
|||||||
if "bert-base-japanese" in pretrained_model_name_or_path:
|
if "bert-base-japanese" in pretrained_model_name_or_path:
|
||||||
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, T5Config):
|
for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
|
||||||
return T5Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, DistilBertConfig):
|
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||||
return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return AlbertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, CamembertConfig):
|
|
||||||
return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
|
||||||
return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, GPT2Config):
|
|
||||||
return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized configuration class {} to build an AutoTokenizer.\n"
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"Model type should be one of {}.".format(
|
||||||
"'xlm-roberta', 'xlm', 'roberta', 'distilbert,' 'camembert', 'ctrl', 'albert'".format(
|
config.__class__, ", ".join(c.__name__ for c in MODEL_MAPPING.keys())
|
||||||
pretrained_model_name_or_path
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -51,4 +51,4 @@ class AutoConfigTest(unittest.TestCase):
|
|||||||
# no key string should be included in a later key string (typical failure case)
|
# no key string should be included in a later key string (typical failure case)
|
||||||
keys = list(CONFIG_MAPPING.keys())
|
keys = list(CONFIG_MAPPING.keys())
|
||||||
for i, key in enumerate(keys):
|
for i, key in enumerate(keys):
|
||||||
self.assertFalse(any(key in later_key for later_key in keys[i+1:]))
|
self.assertFalse(any(key in later_key for later_key in keys[i + 1 :]))
|
||||||
|
|||||||
Reference in New Issue
Block a user