Fix from_pretrained with default base_model_prefix (#15814)

This commit is contained in:
Sylvain Gugger
2022-02-24 11:43:51 +01:00
committed by GitHub
parent 7f921bcf47
commit d1fcc90abf
3 changed files with 12 additions and 7 deletions

View File

@@ -2105,7 +2105,10 @@ class ModelUtilsTest(TestCasePlus):
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = NoSuperInitModel.from_pretrained(tmp_dir)
new_model = NoSuperInitModel.from_pretrained(tmp_dir)
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
@require_torch