From 4f338ed407a0eca44cbdb342beba0c9f23bbea1e Mon Sep 17 00:00:00 2001
From: Gunnlaugur Thor Briem
Date: Wed, 4 Mar 2020 23:45:29 +0000
Subject: [PATCH] Explicit config_class instead of module inspection

---
 src/transformers/configuration_auto.py     | 14 --------------
 src/transformers/modeling_tf_albert.py     |  2 ++
 src/transformers/modeling_tf_bert.py       |  2 ++
 src/transformers/modeling_tf_ctrl.py       |  2 ++
 src/transformers/modeling_tf_gpt2.py       |  2 ++
 src/transformers/modeling_tf_transfo_xl.py |  2 ++
 src/transformers/modeling_tf_utils.py      |  8 +++++---
 src/transformers/modeling_tf_xlnet.py      |  2 ++
 8 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py
index c363c89e3d..37fd2feae3 100644
--- a/src/transformers/configuration_auto.py
+++ b/src/transformers/configuration_auto.py
@@ -101,20 +101,6 @@ class AutoConfig:
             "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method."
         )
 
-    @classmethod
-    def config_class_for_model_class(cls, model_class):
-        module = import_module(model_class.__module__)
-        return next(
-            (
-                module_attribute
-                for module_attribute_name in dir(module)
-                if module_attribute_name.endswith("Config")
-                for module_attribute in (getattr(module, module_attribute_name),)
-                if issubclass(module_attribute, PretrainedConfig)
-            ),
-            None,
-        )
-
     @classmethod
     def for_model(cls, model_type, *args, **kwargs):
         for pattern, config_class in CONFIG_MAPPING.items():
diff --git a/src/transformers/modeling_tf_albert.py b/src/transformers/modeling_tf_albert.py
index 698a1523d6..78b0077d8e 100644
--- a/src/transformers/modeling_tf_albert.py
+++ b/src/transformers/modeling_tf_albert.py
@@ -480,6 +480,8 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFAlbertMainLayer(tf.keras.layers.Layer):
+    config_class = AlbertConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
diff --git a/src/transformers/modeling_tf_bert.py b/src/transformers/modeling_tf_bert.py
index 30bfdb4347..11ae8da6b8 100644
--- a/src/transformers/modeling_tf_bert.py
+++ b/src/transformers/modeling_tf_bert.py
@@ -473,6 +473,8 @@ class TFBertNSPHead(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFBertMainLayer(tf.keras.layers.Layer):
+    config_class = BertConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.num_hidden_layers = config.num_hidden_layers
diff --git a/src/transformers/modeling_tf_ctrl.py b/src/transformers/modeling_tf_ctrl.py
index d801b7cb50..3c13671f4b 100644
--- a/src/transformers/modeling_tf_ctrl.py
+++ b/src/transformers/modeling_tf_ctrl.py
@@ -166,6 +166,8 @@ class TFEncoderLayer(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFCTRLMainLayer(tf.keras.layers.Layer):
+    config_class = CTRLConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_hidden_states = config.output_hidden_states
diff --git a/src/transformers/modeling_tf_gpt2.py b/src/transformers/modeling_tf_gpt2.py
index 7a9a72bd68..18752c34aa 100644
--- a/src/transformers/modeling_tf_gpt2.py
+++ b/src/transformers/modeling_tf_gpt2.py
@@ -199,6 +199,8 @@ class TFBlock(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFGPT2MainLayer(tf.keras.layers.Layer):
+    config_class = GPT2Config
+
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
         self.output_hidden_states = config.output_hidden_states
diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py
index 612e7a711d..53c4dc7bef 100644
--- a/src/transformers/modeling_tf_transfo_xl.py
+++ b/src/transformers/modeling_tf_transfo_xl.py
@@ -380,6 +380,8 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFTransfoXLMainLayer(tf.keras.layers.Layer):
+    config_class = TransfoXLConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 3c94c6bf60..122ee9c911 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -50,11 +50,13 @@ class TFModelUtilsMixin:
 def keras_serializable(cls):
     initializer = cls.__init__
 
+    config_class = getattr(cls, "config_class", None)
+    if config_class is None:
+        raise AttributeError("Must set `config_class` to use @keras_serializable")
+
     def wrapped_init(self, config, *args, **kwargs):
         if isinstance(config, dict):
-            from transformers import AutoConfig
-
-            config = AutoConfig.config_class_for_model_class(cls).from_dict(config)
+            config = config_class.from_dict(config)
         initializer(self, config, *args, **kwargs)
         self._transformers_config = config
 
diff --git a/src/transformers/modeling_tf_xlnet.py b/src/transformers/modeling_tf_xlnet.py
index c899628d88..8797a22194 100644
--- a/src/transformers/modeling_tf_xlnet.py
+++ b/src/transformers/modeling_tf_xlnet.py
@@ -351,6 +351,8 @@ class TFXLNetLMHead(tf.keras.layers.Layer):
 
 @keras_serializable
 class TFXLNetMainLayer(tf.keras.layers.Layer):
+    config_class = XLNetConfig
+
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.output_attentions = config.output_attentions