Allow relative imports in dynamic code (#15352)

* Allow dynamic modules to use relative imports

* Add tests

* Add one last test

* Changes
This commit is contained in:
Sylvain Gugger
2022-01-27 14:47:59 -05:00
committed by GitHub
parent 628b59e51d
commit 0b07230409
3 changed files with 174 additions and 55 deletions

View File

@@ -102,3 +102,7 @@ class AutoConfigTest(unittest.TestCase):
"hf-internal-testing/no-config-test-repo does not appear to have a file named config.json.",
):
_ = AutoConfig.from_pretrained("hf-internal-testing/no-config-test-repo")
def test_from_pretrained_dynamic_config(self):
config = AutoConfig.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(config.__class__.__name__, "NewModelConfig")

View File

@@ -324,7 +324,7 @@ class AutoModelTest(unittest.TestCase):
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
def test_from_pretrained_dynamic_model(self):
def test_from_pretrained_dynamic_model_local(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
@@ -340,6 +340,14 @@ class AutoModelTest(unittest.TestCase):
for p1, p2 in zip(model.parameters(), new_model.parameters()):
self.assertTrue(torch.equal(p1, p2))
def test_from_pretrained_dynamic_model_distant(self):
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
# This one uses a relative import to a util file, this checks it is downloaded and used properly.
model = AutoModel.from_pretrained("hf-internal-testing/test_dynamic_model_with_util", trust_remote_code=True)
self.assertEqual(model.__class__.__name__, "NewModel")
def test_new_model_registration(self):
AutoConfig.register("new-model", NewModelConfig)