from_pretrained: check that the pretrained model is for the right model architecture (#10586)

* Added check to ensure model name passed to from_pretrained and model are the same

* Added test to check from_pretrained throws assert error when passed an incompatiable model name

* Modified assert in from_pretrained with f-strings. Modified test to ensure desired assert message is being generated

* Added check to ensure config and model has model_type

* Fix FlauBERT heads

Co-authored-by: vimarsh chaturvedi <vimarsh chaturvedi>
Co-authored-by: Stas Bekman <stas@stason.org>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Vimarsh Chaturvedi
2021-03-18 22:21:42 +05:30
committed by GitHub
parent 4f3e93cfaf
commit 094afa515d
3 changed files with 21 additions and 0 deletions

View File

@@ -47,6 +47,7 @@ if is_torch_available():
BertModel,
PretrainedConfig,
PreTrainedModel,
T5ForConditionalGeneration,
)
@@ -58,6 +59,9 @@ def _config_zero_init(config):
return configs_no_init
TINY_T5 = "patrickvonplaten/t5-tiny-random"
@require_torch
class ModelTesterMixin:
@@ -1284,3 +1288,11 @@ class ModelUtilsTest(unittest.TestCase):
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
def test_model_from_pretrained_with_different_pretrained_model_name(self):
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
self.assertIsNotNone(model)
with self.assertRaises(Exception) as context:
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You tried to initiate a model of type" in str(context.exception))