from_pretrained: check that the pretrained model is for the right model architecture (#10586)
* Added check to ensure model name passed to from_pretrained and model are the same * Added test to check from_pretrained throws assert error when passed an incompatiable model name * Modified assert in from_pretrained with f-strings. Modified test to ensure desired assert message is being generated * Added check to ensure config and model has model_type * Fix FlauBERT heads Co-authored-by: vimarsh chaturvedi <vimarsh chaturvedi> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
committed by
GitHub
parent
4f3e93cfaf
commit
094afa515d
@@ -384,6 +384,11 @@ class PretrainedConfig(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||||
|
if config_dict.get("model_type", False) and hasattr(cls, "model_type"):
|
||||||
|
assert (
|
||||||
|
config_dict["model_type"] == cls.model_type
|
||||||
|
), f"You tried to initiate a model of type '{cls.model_type}' with a pretrained model of type '{config_dict['model_type']}'"
|
||||||
|
|
||||||
return cls.from_dict(config_dict, **kwargs)
|
return cls.from_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -932,6 +932,8 @@ class TFFlaubertForQuestionAnsweringSimple(TFXLMForQuestionAnsweringSimple):
|
|||||||
FLAUBERT_START_DOCSTRING,
|
FLAUBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
|
class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
|
||||||
|
config_class = FlaubertConfig
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
||||||
@@ -945,6 +947,8 @@ class TFFlaubertForTokenClassification(TFXLMForTokenClassification):
|
|||||||
FLAUBERT_START_DOCSTRING,
|
FLAUBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
|
class TFFlaubertForMultipleChoice(TFXLMForMultipleChoice):
|
||||||
|
config_class = FlaubertConfig
|
||||||
|
|
||||||
def __init__(self, config, *inputs, **kwargs):
|
def __init__(self, config, *inputs, **kwargs):
|
||||||
super().__init__(config, *inputs, **kwargs)
|
super().__init__(config, *inputs, **kwargs)
|
||||||
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
self.transformer = TFFlaubertMainLayer(config, name="transformer")
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ if is_torch_available():
|
|||||||
BertModel,
|
BertModel,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -58,6 +59,9 @@ def _config_zero_init(config):
|
|||||||
return configs_no_init
|
return configs_no_init
|
||||||
|
|
||||||
|
|
||||||
|
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ModelTesterMixin:
|
class ModelTesterMixin:
|
||||||
|
|
||||||
@@ -1284,3 +1288,11 @@ class ModelUtilsTest(unittest.TestCase):
|
|||||||
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||||
self.assertEqual(model.config.output_hidden_states, True)
|
self.assertEqual(model.config.output_hidden_states, True)
|
||||||
self.assertEqual(model.config, config)
|
self.assertEqual(model.config, config)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_with_different_pretrained_model_name(self):
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
with self.assertRaises(Exception) as context:
|
||||||
|
BertModel.from_pretrained(TINY_T5)
|
||||||
|
self.assertTrue("You tried to initiate a model of type" in str(context.exception))
|
||||||
|
|||||||
Reference in New Issue
Block a user