Explicit config_class instead of module inspection
This commit is contained in:
@@ -101,20 +101,6 @@ class AutoConfig:
|
||||
"using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def config_class_for_model_class(cls, model_class):
|
||||
module = import_module(model_class.__module__)
|
||||
return next(
|
||||
(
|
||||
module_attribute
|
||||
for module_attribute_name in dir(module)
|
||||
if module_attribute_name.endswith("Config")
|
||||
for module_attribute in (getattr(module, module_attribute_name),)
|
||||
if issubclass(module_attribute, PretrainedConfig)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def for_model(cls, model_type, *args, **kwargs):
|
||||
for pattern, config_class in CONFIG_MAPPING.items():
|
||||
|
||||
@@ -480,6 +480,8 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
|
||||
|
||||
@keras_serializable
|
||||
class TFAlbertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = AlbertConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
@@ -473,6 +473,8 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
||||
|
||||
@keras_serializable
|
||||
class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
config_class = BertConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
@@ -166,6 +166,8 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
||||
|
||||
@keras_serializable
|
||||
class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
config_class = CTRLConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
@@ -199,6 +199,8 @@ class TFBlock(tf.keras.layers.Layer):
|
||||
|
||||
@keras_serializable
|
||||
class TFGPT2MainLayer(tf.keras.layers.Layer):
|
||||
config_class = GPT2Config
|
||||
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__(*inputs, **kwargs)
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
@@ -380,6 +380,8 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
|
||||
|
||||
@keras_serializable
|
||||
class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
config_class = TransfoXLConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
@@ -50,11 +50,13 @@ class TFModelUtilsMixin:
|
||||
def keras_serializable(cls):
|
||||
initializer = cls.__init__
|
||||
|
||||
config_class = getattr(cls, "config_class", None)
|
||||
if config_class is None:
|
||||
raise AttributeError("Must set `config_class` to use @keras_serializable")
|
||||
|
||||
def wrapped_init(self, config, *args, **kwargs):
|
||||
if isinstance(config, dict):
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
|
||||
config = config_class.from_dict(config)
|
||||
initializer(self, config, *args, **kwargs)
|
||||
self._transformers_config = config
|
||||
|
||||
|
||||
@@ -351,6 +351,8 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
|
||||
|
||||
@keras_serializable
|
||||
class TFXLNetMainLayer(tf.keras.layers.Layer):
|
||||
config_class = XLNetConfig
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.output_attentions = config.output_attentions
|
||||
|
||||
Reference in New Issue
Block a user