Use commit hash to look in cache instead of calling head (#18534)

* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments

Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
Sylvain Gugger
2022-08-10 11:55:18 -04:00
committed by GitHub
parent 6eb51450fa
commit 0d0aada564
15 changed files with 221 additions and 23 deletions

View File

@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.push_to_hub("test-config", use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
base_config = PretrainedConfig()
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
self.assertListEqual(
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(