From 81009b7a5c5cb183a9275c15bf347bdc988b02c4 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 13 Apr 2021 10:33:52 -0400 Subject: [PATCH] Replace error by warning when loading an architecture in another (#11207) * Replace error by warning when loading an architecture in another * Style * Style again * Add a test * Adapt old test --- src/transformers/configuration_utils.py | 9 +++++---- tests/test_modeling_bert_generation.py | 13 ++++++------- tests/test_modeling_common.py | 9 +++++---- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index ad517ba154..2b08d10b24 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -399,10 +399,11 @@ class PretrainedConfig(object): """ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - if config_dict.get("model_type", False) and hasattr(cls, "model_type"): - assert ( - config_dict["model_type"] == cls.model_type - ), f"You tried to initiate a model of type '{cls.model_type}' with a pretrained model of type '{config_dict['model_type']}'" + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warn( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) return cls.from_dict(config_dict, **kwargs) diff --git a/tests/test_modeling_bert_generation.py b/tests/test_modeling_bert_generation.py index 2048c127e9..0ca0d81f40 100755 --- a/tests/test_modeling_bert_generation.py +++ b/tests/test_modeling_bert_generation.py @@ -231,13 +231,7 @@ class BertGenerationEncoderTester: self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - ( - config, - input_ids, - input_mask, - token_labels, - ) = config_and_inputs + config, input_ids, input_mask, token_labels = self.prepare_config_and_inputs() inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} return config, inputs_dict @@ -259,6 +253,11 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_model_as_bert(self): + config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs() + config.model_type = "bert" + self.model_tester.create_and_check_model(config, input_ids, input_mask, token_labels) + def test_model_as_decoder(self): config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() self.model_tester.create_and_check_model_as_decoder(*config_and_inputs) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d5d76162bc..419b92e280 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,10 +22,10 @@ import tempfile import unittest from typing import List, Tuple -from transformers import is_torch_available +from transformers import is_torch_available, logging from transformers.file_utils import WEIGHTS_NAME from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device +from transformers.testing_utils import CaptureLogger, require_torch, require_torch_multi_gpu, slow, torch_device if is_torch_available(): @@ -1296,6 +1296,7 @@ class ModelUtilsTest(unittest.TestCase): model = T5ForConditionalGeneration.from_pretrained(TINY_T5) self.assertIsNotNone(model) - with self.assertRaises(Exception) as context: + logger = logging.get_logger("transformers.configuration_utils") + with CaptureLogger(logger) as cl: BertModel.from_pretrained(TINY_T5) - self.assertTrue("You tried to initiate a model of type" in str(context.exception)) + self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)