From f34372a9ff99f6bc8619ac83dc07f7afe6b92141 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 22 Oct 2020 15:39:01 +0200 Subject: [PATCH] [PretrainedConfig] Fix save pretrained config for edge case (#7943) * fix config save * add test * add config class variable and another test * line break * fix fsmt and typo * god am I making many errors today :-/ * Update src/transformers/configuration_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/configuration_encoder_decoder.py | 1 + src/transformers/configuration_fsmt.py | 6 +++--- src/transformers/configuration_rag.py | 1 + src/transformers/configuration_utils.py | 14 +++++++++++++- tests/test_configuration_common.py | 7 +++++++ tests/test_modeling_prophetnet.py | 9 +++++++++ 6 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/transformers/configuration_encoder_decoder.py b/src/transformers/configuration_encoder_decoder.py index eff92bf245..e357d15a06 100644 --- a/src/transformers/configuration_encoder_decoder.py +++ b/src/transformers/configuration_encoder_decoder.py @@ -70,6 +70,7 @@ class EncoderDecoderConfig(PretrainedConfig): >>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config) """ model_type = "encoder_decoder" + is_composition = True def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/transformers/configuration_fsmt.py b/src/transformers/configuration_fsmt.py index 747f47dd52..b20328bc43 100644 --- a/src/transformers/configuration_fsmt.py +++ b/src/transformers/configuration_fsmt.py @@ -126,9 +126,9 @@ class FSMTConfig(PretrainedConfig): # update the defaults from config file def __init__( self, - langs, - src_vocab_size, - tgt_vocab_size, + langs=["en", "de"], + src_vocab_size=42024, + tgt_vocab_size=42024, activation_function="relu", d_model=1024, max_length=200, diff --git a/src/transformers/configuration_rag.py b/src/transformers/configuration_rag.py index 30baca04c5..c18e1980b4 100644 --- a/src/transformers/configuration_rag.py +++ b/src/transformers/configuration_rag.py @@ -77,6 +77,7 @@ RAG_CONFIG_DOC = r""" @add_start_docstrings(RAG_CONFIG_DOC) class RagConfig(PretrainedConfig): model_type = "rag" + is_composition = True def __init__( self, diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index c4044aece5..57f635bfbc 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -41,6 +41,10 @@ class PretrainedConfig(object): Class attributes (overridden by derived classes) - **model_type** (:obj:`str`): An identifier for the model type, serialized into the JSON file, and used to recreate the correct object in :class:`~transformers.AutoConfig`. + - **is_composition** (:obj:`bool`): Whether the config class is composed of multiple + sub-configs. In this case the config has to be initialized from two or more configs of + type :class:`~transformers.PretrainedConfig` like: :class:`~transformers.EncoderDecoderConfig` or + :class:`~RagConfig`. Args: name_or_path (:obj:`str`, `optional`, defaults to :obj:`""`): @@ -145,6 +149,7 @@ class PretrainedConfig(object): use BFloat16 scalars (only used by some TensorFlow models). """ model_type: str = "" + is_composition: bool = False def __init__(self, **kwargs): # Attributes with defaults @@ -476,11 +481,18 @@ class PretrainedConfig(object): # get the default config dict default_config_dict = PretrainedConfig().to_dict() + # get class specific config dict + class_config_dict = self.__class__().to_dict() if not self.is_composition else {} + serializable_config_dict = {} # only serialize values that differ from the default config for key, value in config_dict.items(): - if key not in default_config_dict or value != default_config_dict[key]: + if ( + key not in default_config_dict + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): serializable_config_dict[key] = value return serializable_config_dict diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 7498ae6caf..53dbc9eeb9 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -66,9 +66,16 @@ class ConfigTester(object): self.parent.assertEqual(len(config.id2label), 3) self.parent.assertEqual(len(config.label2id), 3) + def check_config_can_be_init_without_params(self): + if self.config_class.is_composition: + return + config = self.config_class() + self.parent.assertIsNotNone(config) + def run_common_tests(self): self.create_and_test_config_common_properties() self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_file() self.create_and_test_config_from_and_save_pretrained() self.create_and_test_config_with_num_labels() + self.check_config_can_be_init_without_params() diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index 90ca042db8..55336c5d2f 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -901,6 +901,15 @@ class ProphetNetModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_model_with_attn_mask(*config_and_inputs) + def test_config_save(self): + config = self.model_tester.prepare_config_and_inputs()[0] + config.add_cross_attention = False + with tempfile.TemporaryDirectory() as tmp_dirname: + config.save_pretrained(tmp_dirname) + config = ProphetNetConfig.from_pretrained(tmp_dirname) + + self.assertFalse(config.add_cross_attention) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs()