Use commit hash to look in cache instead of calling head (#18534)
* Use commit hash to look in cache instead of calling head * Add tests * Add attr for local configs too * Stupid typos * Fix tests * Update src/transformers/utils/hub.py Co-authored-by: Julien Chaumond <julien@huggingface.co> * Address Julien's comments Co-authored-by: Julien Chaumond <julien@huggingface.co>
This commit is contained in:
@@ -24,6 +24,7 @@ from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
RequestCounter,
|
||||
require_scatter,
|
||||
require_torch,
|
||||
slow,
|
||||
@@ -354,3 +355,21 @@ class AutoModelTest(unittest.TestCase):
|
||||
def test_model_from_flax_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
def test_cached_model_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the model.
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
# With a sharded checkpoint
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# There is no pytorch_model.bin so we still get one call for this one.
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@@ -21,6 +21,7 @@ from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5C
|
||||
from transformers.testing_utils import (
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
RequestCounter,
|
||||
require_tensorflow_probability,
|
||||
require_tf,
|
||||
slow,
|
||||
@@ -287,3 +288,21 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
def test_model_from_pt_suggestion(self):
|
||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_pt=True` to load this model"):
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
||||
def test_cached_model_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the model.
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
self.assertEqual(counter.head_request_count, 1)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
# With a sharded checkpoint
|
||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
with RequestCounter() as counter:
|
||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# There is no pytorch_model.bin so we still get one call for this one.
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@@ -48,6 +48,7 @@ from transformers.testing_utils import (
|
||||
DUMMY_DIFF_TOKENIZER_IDENTIFIER,
|
||||
DUMMY_UNKNOWN_IDENTIFIER,
|
||||
SMALL_MODEL_IDENTIFIER,
|
||||
RequestCounter,
|
||||
require_tokenizers,
|
||||
slow,
|
||||
)
|
||||
@@ -213,6 +214,7 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
def test_get_tokenizer_config(self):
|
||||
# Check we can load the tokenizer config of an online model.
|
||||
config = get_tokenizer_config("bert-base-cased")
|
||||
_ = config.pop("_commit_hash", None)
|
||||
# If we ever update bert-base-cased tokenizer config, this dict here will need to be updated.
|
||||
self.assertEqual(config, {"do_lower_case": False})
|
||||
|
||||
@@ -340,3 +342,13 @@ class AutoTokenizerTest(unittest.TestCase):
|
||||
EnvironmentError, r"aaaaaa is not a valid git identifier \(branch name, tag name or commit id\)"
|
||||
):
|
||||
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||
|
||||
def test_cached_tokenizer_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the tokenizer.
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# We still have one extra call because the model does not have a added_tokens.json file
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
@@ -49,6 +49,7 @@ from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
RequestCounter,
|
||||
is_pipeline_test,
|
||||
is_staging_test,
|
||||
nested_simplify,
|
||||
@@ -877,6 +878,16 @@ class CustomPipelineTest(unittest.TestCase):
|
||||
[{"label": "LABEL_0", "score": 0.505}],
|
||||
)
|
||||
|
||||
def test_cached_pipeline_has_minimum_calls_to_head(self):
|
||||
# Make sure we have cached the pipeline.
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
with RequestCounter() as counter:
|
||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||
self.assertEqual(counter.get_request_count, 0)
|
||||
# We still have one extra call because the model does not have a added_tokens.json file
|
||||
self.assertEqual(counter.head_request_count, 2)
|
||||
self.assertEqual(counter.other_request_count, 0)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
@@ -246,7 +246,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config.push_to_hub("test-config", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -258,7 +258,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -269,7 +269,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -283,7 +283,7 @@ class ConfigPushToHubTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.__dict__.items():
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
@@ -323,7 +323,9 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
base_config = PretrainedConfig()
|
||||
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
||||
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
||||
self.assertListEqual(missing_keys, ["is_encoder_decoder", "_name_or_path", "transformers_version"])
|
||||
self.assertListEqual(
|
||||
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
|
||||
)
|
||||
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
||||
if len(keys_with_defaults) > 0:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user