Explicit config_class instead of module inspection
This commit is contained in:
@@ -101,20 +101,6 @@ class AutoConfig:
|
|||||||
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_class_for_model_class(cls, model_class):
|
|
||||||
module = import_module(model_class.__module__)
|
|
||||||
return next(
|
|
||||||
(
|
|
||||||
module_attribute
|
|
||||||
for module_attribute_name in dir(module)
|
|
||||||
if module_attribute_name.endswith("Config")
|
|
||||||
for module_attribute in (getattr(module, module_attribute_name),)
|
|
||||||
if issubclass(module_attribute, PretrainedConfig)
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_model(cls, model_type, *args, **kwargs):
|
def for_model(cls, model_type, *args, **kwargs):
|
||||||
for pattern, config_class in CONFIG_MAPPING.items():
|
for pattern, config_class in CONFIG_MAPPING.items():
|
||||||
|
|||||||
@@ -480,6 +480,8 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = AlbertConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@@ -473,6 +473,8 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFBertMainLayer(tf.keras.layers.Layer):
|
class TFBertMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = BertConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_hidden_layers = config.num_hidden_layers
|
self.num_hidden_layers = config.num_hidden_layers
|
||||||
|
|||||||
@@ -166,6 +166,8 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = CTRLConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|||||||
@@ -199,6 +199,8 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = GPT2Config
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|||||||
@@ -380,6 +380,8 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = TransfoXLConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
|
|||||||
@@ -50,11 +50,13 @@ class TFModelUtilsMixin:
|
|||||||
def keras_serializable(cls):
|
def keras_serializable(cls):
|
||||||
initializer = cls.__init__
|
initializer = cls.__init__
|
||||||
|
|
||||||
|
config_class = getattr(cls, "config_class", None)
|
||||||
|
if config_class is None:
|
||||||
|
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
||||||
|
|
||||||
def wrapped_init(self, config, *args, **kwargs):
|
def wrapped_init(self, config, *args, **kwargs):
|
||||||
if isinstance(config, dict):
|
if isinstance(config, dict):
|
||||||
from transformers import AutoConfig
|
config = config_class.from_dict(config)
|
||||||
|
|
||||||
config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
|
|
||||||
initializer(self, config, *args, **kwargs)
|
initializer(self, config, *args, **kwargs)
|
||||||
self._transformers_config = config
|
self._transformers_config = config
|
||||||
|
|
||||||
|
|||||||
@@ -351,6 +351,8 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
class TFXLNetMainLayer(tf.keras.layers.Layer):
|
class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||||
|
config_class = XLNetConfig
|
||||||
|
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
|
|||||||
Reference in New Issue
Block a user