From d4ba6e1a0e8f662f3deadba25d982c6fb5fb772c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 14 Feb 2023 10:57:28 -0500 Subject: [PATCH] Fix generation config for empty state dict (#21630) --- src/transformers/modeling_utils.py | 2 +- tests/test_modeling_common.py | 21 ++++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index db3b9c28f7..fa28feef3c 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2648,7 +2648,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _from_pipeline=from_pipeline, **kwargs, ) - except OSError: + except (OSError, TypeError): logger.info( "Generation config file not found, using a generation config created from the model config." ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 7cb5c4478c..152ea7d6cd 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -325,6 +325,18 @@ class ModelTesterMixin: else: check_save_load(first, second) + def test_from_pretrained_no_checkpoint(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + state_dict = model.state_dict() + + new_model = model_class.from_pretrained( + pretrained_model_name_or_path=None, config=config, state_dict=state_dict + ) + for p1, p2 in zip(model.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + def test_save_load_keys_to_ignore_on_save(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -2776,15 +2788,6 @@ class ModelUtilsTest(TestCasePlus): BertModel.from_pretrained(TINY_T5) self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out) - def test_model_from_pretrained_no_checkpoint(self): - config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") - model = BertModel(config) - state_dict = model.state_dict() - - new_model = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=state_dict) - for p1, p2 in zip(model.parameters(), new_model.parameters()): - self.assertTrue(torch.equal(p1, p2)) - def test_model_from_config_torch_dtype(self): # test that the model can be instantiated with dtype of user's choice - as long as it's a # float dtype. To make it happen config.torch_dtype needs to be set before instantiating the