Fix from_pretrained API with config and state_dict (#21542)

This commit is contained in:
Sylvain Gugger
2023-02-09 15:44:02 -05:00
committed by GitHub
parent 1efe9c0b24
commit 2020ac4bd6
2 changed files with 13 additions and 1 deletions

View File

@@ -2749,6 +2749,15 @@ class ModelUtilsTest(TestCasePlus):
BertModel.from_pretrained(TINY_T5)
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
def test_model_from_pretrained_no_checkpoint(self):
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
model = BertModel(config)
state_dict = model.state_dict()
new_model = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=state_dict)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_torch
def test_model_from_config_torch_dtype(self):
# test that the model can be instantiated with dtype of user's choice - as long as it's a