Use commit hash to look in cache instead of calling head (#18534)

* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments

Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
Sylvain Gugger
2022-08-10 11:55:18 -04:00
committed by GitHub
parent 6eb51450fa
commit 0d0aada564
15 changed files with 221 additions and 23 deletions

View File

@@ -24,6 +24,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
RequestCounter,
require_scatter,
require_torch,
slow,
@@ -354,3 +355,21 @@ class AutoModelTest(unittest.TestCase):
def test_model_from_flax_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
def test_cached_model_has_minimum_calls_to_head(self):
# Loads a checkpoint once, then loads it again under RequestCounter and asserts on the
# counter's get/head/other request counts for the second load.
# NOTE(review): indentation is not visible in this view, so whether the assertions run
# inside or after each `with` block cannot be determined from here.
# Make sure we have cached the model.
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Second load of the same checkpoint: expected to need no GET and exactly one HEAD request.
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
# With a sharded checkpoint
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
# `counter` is rebound here, so the counts below cover only the sharded reload.
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)

View File

@@ -21,6 +21,7 @@ from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5C
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
RequestCounter,
require_tensorflow_probability,
require_tf,
slow,
@@ -287,3 +288,21 @@ class TFAutoModelTest(unittest.TestCase):
def test_model_from_pt_suggestion(self):
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
def test_cached_model_has_minimum_calls_to_head(self):
# Loads a checkpoint once with TFAutoModel, then loads it again under RequestCounter and
# asserts on the counter's get/head/other request counts for the second load.
# NOTE(review): indentation is not visible in this view, so whether the assertions run
# inside or after each `with` block cannot be determined from here.
# Make sure we have cached the model.
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Second load of the same checkpoint: expected to need no GET and exactly one HEAD request.
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
# With a sharded checkpoint
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
# `counter` is rebound here, so the counts below cover only the sharded reload.
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
# NOTE(review): the comment above names pytorch_model.bin although this is the TF test;
# presumably the missing file is the TF single-file checkpoint - confirm.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)

View File

@@ -48,6 +48,7 @@ from transformers.testing_utils import (
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
RequestCounter,
require_tokenizers,
slow,
)
@@ -213,6 +214,7 @@ class AutoTokenizerTest(unittest.TestCase):
def test_get_tokenizer_config(self):
# Check we can load the tokenizer config of an online model.
config = get_tokenizer_config("bert-base-cased")
_ = config.pop("_commit_hash", None)
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
self.assertEqual(config, {"do_lower_case": False})
@@ -340,3 +342,13 @@ class AutoTokenizerTest(unittest.TestCase):
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
):
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
def test_cached_tokenizer_has_minimum_calls_to_head(self):
# Loads a tokenizer once, then loads it again under RequestCounter and asserts on the
# counter's get/head/other request counts for the second load.
# NOTE(review): indentation is not visible in this view, so whether the assertions run
# inside or after the `with` block cannot be determined from here.
# Make sure we have cached the tokenizer.
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have an added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)

View File

@@ -49,6 +49,7 @@ from transformers.testing_utils import (
TOKEN,
USER,
CaptureLogger,
RequestCounter,
is_pipeline_test,
is_staging_test,
nested_simplify,
@@ -877,6 +878,16 @@ class CustomPipelineTest(unittest.TestCase):
[{"label": "LABEL_0", "score": 0.505}],
)
def test_cached_pipeline_has_minimum_calls_to_head(self):
# Builds a text-classification pipeline once, then builds it again under RequestCounter and
# asserts on the counter's get/head/other request counts for the second build.
# NOTE(review): indentation is not visible in this view, so whether the assertions run
# inside or after the `with` block cannot be determined from here.
# Make sure we have cached the pipeline.
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have an added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.other_request_count, 0)
@require_torch
@is_staging_test

View File

@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.push_to_hub("test-config", use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
)
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
base_config = PretrainedConfig()
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
# If this part of the test fails, you have arguments to add in config_common_kwargs above.
self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
self.assertListEqual(
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
raise ValueError(