Config to Model mapping
This commit is contained in:
@@ -62,6 +62,7 @@ class BertAbsConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
|
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
|
||||||
|
model_type = "bertabs"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ class AlbertConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "albert"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ class AutoConfig:
|
|||||||
assert unused_kwargs == {'foo': False}
|
assert unused_kwargs == {'foo': False}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_dict, _ = PretrainedConfig.resolved_config_dict(
|
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||||
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
|
pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class BertConfig(PretrainedConfig):
|
|||||||
layer_norm_eps: The epsilon used by LayerNorm.
|
layer_norm_eps: The epsilon used by LayerNorm.
|
||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "bert"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -30,3 +30,4 @@ CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
|
|
||||||
class CamembertConfig(RobertaConfig):
|
class CamembertConfig(RobertaConfig):
|
||||||
pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "camembert"
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class CTRLConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "ctrl"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
|
|
||||||
class DistilBertConfig(PretrainedConfig):
|
class DistilBertConfig(PretrainedConfig):
|
||||||
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "distilbert"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class GPT2Config(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "gpt2"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class OpenAIGPTConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "openai-gpt"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ class T5Config(PretrainedConfig):
|
|||||||
layer_norm_eps: The epsilon used by LayerNorm.
|
layer_norm_eps: The epsilon used by LayerNorm.
|
||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "t5"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "transfo-xl"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -46,7 +46,8 @@ class PretrainedConfig(object):
|
|||||||
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
|
``output_hidden_states``: string, default `False`. Should the model returns all hidden-states.
|
||||||
``torchscript``: string, default `False`. Is the model used with Torchscript.
|
``torchscript``: string, default `False`. Is the model used with Torchscript.
|
||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = {}
|
pretrained_config_archive_map: Dict[str, str] = {}
|
||||||
|
model_type: str
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# Attributes with defaults
|
# Attributes with defaults
|
||||||
@@ -155,11 +156,11 @@ class PretrainedConfig(object):
|
|||||||
assert unused_kwargs == {'foo': False}
|
assert unused_kwargs == {'foo': False}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_dict, kwargs = cls.resolved_config_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolved_config_dict(
|
def get_config_dict(
|
||||||
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
|
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
|
||||||
) -> Tuple[Dict, Dict]:
|
) -> Tuple[Dict, Dict]:
|
||||||
"""
|
"""
|
||||||
@@ -257,7 +258,7 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json_file(cls, json_file: str):
|
def from_json_file(cls, json_file: str):
|
||||||
"""Constructs a `Config` from a json file of parameters."""
|
"""Constructs a `Config` from the path to a json file of parameters."""
|
||||||
config_dict = cls._dict_from_json_file(json_file)
|
config_dict = cls._dict_from_json_file(json_file)
|
||||||
return cls(**config_dict)
|
return cls(**config_dict)
|
||||||
|
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class XLMConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "xlm"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -35,3 +35,4 @@ XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
|||||||
|
|
||||||
class XLMRobertaConfig(RobertaConfig):
|
class XLMRobertaConfig(RobertaConfig):
|
||||||
pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "xlm-roberta"
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ class XLNetConfig(PretrainedConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "xlnet"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
from .configuration_auto import (
|
from .configuration_auto import (
|
||||||
AlbertConfig,
|
AlbertConfig,
|
||||||
@@ -76,6 +78,7 @@ from .modeling_roberta import (
|
|||||||
)
|
)
|
||||||
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
|
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5Model, T5WithLMHeadModel
|
||||||
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
|
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
|
||||||
|
from .modeling_utils import PreTrainedModel
|
||||||
from .modeling_xlm import (
|
from .modeling_xlm import (
|
||||||
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||||
XLMForQuestionAnswering,
|
XLMForQuestionAnswering,
|
||||||
@@ -123,6 +126,35 @@ ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
|
|||||||
for key, value, in pretrained_map.items()
|
for key, value, in pretrained_map.items()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(T5Config, T5Model),
|
||||||
|
(DistilBertConfig, DistilBertModel),
|
||||||
|
(AlbertConfig, AlbertModel),
|
||||||
|
(CamembertConfig, CamembertModel),
|
||||||
|
(RobertaConfig, XLMRobertaModel),
|
||||||
|
(XLMRobertaConfig, RobertaModel),
|
||||||
|
(BertConfig, BertModel),
|
||||||
|
(OpenAIGPTConfig, OpenAIGPTModel),
|
||||||
|
(GPT2Config, GPT2Model),
|
||||||
|
(TransfoXLConfig, TransfoXLModel),
|
||||||
|
(XLNetConfig, XLNetModel),
|
||||||
|
(XLMConfig, XLMModel),
|
||||||
|
(CTRLConfig, CTRLModel),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING: OrderedDict[Type[PretrainedConfig], Type[PreTrainedModel]] = OrderedDict(
|
||||||
|
[
|
||||||
|
(DistilBertConfig, DistilBertForTokenClassification),
|
||||||
|
(CamembertConfig, CamembertForTokenClassification),
|
||||||
|
(RobertaConfig, XLMRobertaForTokenClassification),
|
||||||
|
(XLMRobertaConfig, RobertaForTokenClassification),
|
||||||
|
(BertConfig, BertForTokenClassification),
|
||||||
|
(XLNetConfig, XLNetForTokenClassification),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(object):
|
class AutoModel(object):
|
||||||
r"""
|
r"""
|
||||||
@@ -183,30 +215,9 @@ class AutoModel(object):
|
|||||||
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
||||||
model = AutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
model = AutoModel.from_config(config) # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
||||||
"""
|
"""
|
||||||
if isinstance(config, DistilBertConfig):
|
for config_class, model_class in MODEL_MAPPING.items():
|
||||||
return DistilBertModel(config)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, RobertaConfig):
|
return model_class(config)
|
||||||
return RobertaModel(config)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertModel(config)
|
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
|
||||||
return OpenAIGPTModel(config)
|
|
||||||
elif isinstance(config, GPT2Config):
|
|
||||||
return GPT2Model(config)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TransfoXLModel(config)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetModel(config)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMModel(config)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return CTRLModel(config)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return AlbertModel(config)
|
|
||||||
elif isinstance(config, CamembertConfig):
|
|
||||||
return CamembertModel(config)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaModel(config)
|
|
||||||
raise ValueError("Unrecognized configuration class {}".format(config))
|
raise ValueError("Unrecognized configuration class {}".format(config))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -294,32 +305,9 @@ class AutoModel(object):
|
|||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
if isinstance(config, T5Config):
|
for config_class, model_class in MODEL_MAPPING.items():
|
||||||
return T5Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
if isinstance(config, config_class):
|
||||||
elif isinstance(config, DistilBertConfig):
|
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
||||||
return DistilBertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, AlbertConfig):
|
|
||||||
return AlbertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, CamembertConfig):
|
|
||||||
return CamembertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, XLMRobertaConfig):
|
|
||||||
return XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, RobertaConfig):
|
|
||||||
return RobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, BertConfig):
|
|
||||||
return BertModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, OpenAIGPTConfig):
|
|
||||||
return OpenAIGPTModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, GPT2Config):
|
|
||||||
return GPT2Model.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, TransfoXLConfig):
|
|
||||||
return TransfoXLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, XLNetConfig):
|
|
||||||
return XLNetModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, XLMConfig):
|
|
||||||
return XLMModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
elif isinstance(config, CTRLConfig):
|
|
||||||
return CTRLModel.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unrecognized model identifier in {}. Should contains one of "
|
"Unrecognized model identifier in {}. Should contains one of "
|
||||||
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class XxxConfig(PretrainedConfig):
|
|||||||
layer_norm_eps: The epsilon used by LayerNorm.
|
layer_norm_eps: The epsilon used by LayerNorm.
|
||||||
"""
|
"""
|
||||||
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||||
|
model_type = "xxx"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user