Further reduce the number of alls to head for cached objects (#18871)

* Further reduce the number of alls to head for cached models/tokenizers/pipelines

* Fix tests

* Address review comments
This commit is contained in:
Sylvain Gugger
2022-09-06 12:34:37 -04:00
committed by GitHub
parent 6678350c01
commit 71ff88fa4f
5 changed files with 36 additions and 10 deletions

View File

@@ -370,6 +370,5 @@ class AutoModelTest(unittest.TestCase):
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)

View File

@@ -303,6 +303,5 @@ class TFAutoModelTest(unittest.TestCase):
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)

View File

@@ -349,6 +349,5 @@ class AutoTokenizerTest(unittest.TestCase):
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have a added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)

View File

@@ -884,8 +884,7 @@ class CustomPipelineTest(unittest.TestCase):
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have a added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)