Use code on the Hub from another repo (#22814)

* initial work

* Add other classes

* Refactor code

* Move warning and fix dynamic pipeline

* Issue warning when necessary

* Add test

* Do not skip auto tests

* Fix failing tests

* Refactor and address review comments

* Address review comments
This commit is contained in:
Sylvain Gugger
2023-04-18 13:46:11 -04:00
committed by GitHub
parent aec10d162f
commit 5f9b825c89
15 changed files with 124 additions and 66 deletions

View File

@@ -298,6 +298,34 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_from_pretrained_dynamic_model_distant_with_ref(self):
model = AutoModel.from_pretrained("hf-internal-testing/ref_to_test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
# Test model can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
# This one uses a relative import to a util file, this checks it is downloaded and used properly.
model = AutoModel.from_pretrained(
"hf-internal-testing/ref_to_test_dynamic_model_with_util", trust_remote_code=True
)
self.assertEqual(model.__class__.__name__, "NewModel")
# Test model can be reloaded.
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
reloaded_model = AutoModel.from_pretrained(tmp_dir, trust_remote_code=True)
self.assertEqual(reloaded_model.__class__.__name__, "NewModel")
for p1, p2 in zip(model.parameters(), reloaded_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_new_model_registration(self):
AutoConfig.register("custom", CustomConfig)