From 094afa515d78527530fb1960db0b4970747a2031 Mon Sep 17 00:00:00 2001 From: Vimarsh Chaturvedi Date: Thu, 18 Mar 2021 22:21:42 +0530 Subject: [PATCH] from_pretrained: check that the pretrained model is for the right model architecture (#10586) * Added check to ensure model name passed to from_pretrained and model are the same * Added test to check from_pretrained throws assert error when passed an incompatiable model name * Modified assert in from_pretrained with f-strings. Modified test to ensure desired assert message is being generated * Added check to ensure config and model has model_type * Fix FlauBERT heads Co-authored-by: vimarsh chaturvedi Co-authored-by: Stas Bekman Co-authored-by: Lysandre --- src/transformers/configuration_utils.py | 5 +++++ .../models/flaubert/modeling_tf_flaubert.py | 4 ++++ tests/test_modeling_common.py | 12 ++++++++++++ 3 files changed, 21 insertions(+) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 4e5de61386..c6830f5083 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -384,6 +384,11 @@ class PretrainedConfig(object): """ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if config_dict.get("model_type", False) and hasattr(cls, "model_type"): + assert ( + config_dict["model_type"] == cls.model_type + ), f"You tried to initiate a model of type '{cls.model_type}' with a pretrained model of type '{config_dict['model_type']}'" + return cls.from_dict(config_dict, **kwargs) @classmethod diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index b5f8c7b199..646c5da050 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -932,6 +932,8 @@ class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple): FLAUBERT_START_DOCSTRING, ) class TFFlaubertForTokenClassification(TFXLMForTokenClassification): + config_class = FlaubertConfig + def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.transformer = TFFlaubertMainLayer(config, name="transformer") @@ -945,6 +947,8 @@ class TFFlaubertForTokenClassification(TFXLMForTokenClassification): FLAUBERT_START_DOCSTRING, ) class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice): + config_class = FlaubertConfig + def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) self.transformer = TFFlaubertMainLayer(config, name="transformer") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index afded0b3fe..96f5d505ad 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -47,6 +47,7 @@ if is_torch_available(): BertModel, PretrainedConfig, PreTrainedModel, + T5ForConditionalGeneration, ) @@ -58,6 +59,9 @@ def _config_zero_init(config): return configs_no_init +TINY_T5 = "patrickvonplaten/t5-tiny-random" + + @require_torch class ModelTesterMixin: @@ -1284,3 +1288,11 @@ class ModelUtilsTest(unittest.TestCase): model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True) self.assertEqual(model.config.output_hidden_states, True) self.assertEqual(model.config, config) + + def test_model_from_pretrained_with_different_pretrained_model_name(self): + model = T5ForConditionalGeneration.from_pretrained(TINY_T5) + self.assertIsNotNone(model) + + with self.assertRaises(Exception) as context: + BertModel.from_pretrained(TINY_T5) + self.assertTrue("You tried to initiate a model of type" in str(context.exception))