Fix RequestCounter to make it more future-proof (#27406)
* Fix RequestCounter to make it more future-proof * code quality
This commit is contained in:
@@ -29,14 +29,15 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import huggingface_hub
|
import urllib3
|
||||||
import requests
|
|
||||||
|
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
|
|||||||
class RequestCounter:
|
class RequestCounter:
|
||||||
"""
|
"""
|
||||||
Helper class that will count all requests made online.
|
Helper class that will count all requests made online.
|
||||||
|
|
||||||
|
Might not be robust if urllib3 changes its logging format but should be good enough for us.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
```py
|
||||||
|
with RequestCounter() as counter:
|
||||||
|
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
assert counter["GET"] == 0
|
||||||
|
assert counter["HEAD"] == 1
|
||||||
|
assert counter.total_calls == 1
|
||||||
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.head_request_count = 0
|
self._counter = defaultdict(int)
|
||||||
self.get_request_count = 0
|
self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
|
||||||
self.other_request_count = 0
|
self.mock = self.patcher.start()
|
||||||
|
|
||||||
# Mock `get_session` to count HTTP calls.
|
|
||||||
self.old_get_session = huggingface_hub.utils._http.get_session
|
|
||||||
self.session = requests.Session()
|
|
||||||
self.session.request = self.new_request
|
|
||||||
huggingface_hub.utils._http.get_session = lambda: self.session
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *args, **kwargs):
|
def __exit__(self, *args, **kwargs) -> None:
|
||||||
huggingface_hub.utils._http.get_session = self.old_get_session
|
for call in self.mock.call_args_list:
|
||||||
|
log = call.args[0] % call.args[1:]
|
||||||
|
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
|
||||||
|
if method in log:
|
||||||
|
self._counter[method] += 1
|
||||||
|
break
|
||||||
|
self.patcher.stop()
|
||||||
|
|
||||||
def new_request(self, method, **kwargs):
|
def __getitem__(self, key: str) -> int:
|
||||||
if method == "GET":
|
return self._counter[key]
|
||||||
self.get_request_count += 1
|
|
||||||
elif method == "HEAD":
|
|
||||||
self.head_request_count += 1
|
|
||||||
else:
|
|
||||||
self.other_request_count += 1
|
|
||||||
|
|
||||||
return requests.request(method=method, **kwargs)
|
@property
|
||||||
|
def total_calls(self) -> int:
|
||||||
|
return sum(self._counter.values())
|
||||||
|
|
||||||
|
|
||||||
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
|
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
|
||||||
|
|||||||
@@ -482,25 +482,22 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
with self.assertRaisesRegex(EnvironmentError, "Use `from_flax=True` to load this model"):
|
||||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
|
||||||
@unittest.skip(
|
|
||||||
"Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
|
|
||||||
)
|
|
||||||
def test_cached_model_has_minimum_calls_to_head(self):
|
def test_cached_model_has_minimum_calls_to_head(self):
|
||||||
# Make sure we have cached the model.
|
# Make sure we have cached the model.
|
||||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
with RequestCounter() as counter:
|
with RequestCounter() as counter:
|
||||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter["GET"], 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter["HEAD"], 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.total_calls, 1)
|
||||||
|
|
||||||
# With a sharded checkpoint
|
# With a sharded checkpoint
|
||||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
with RequestCounter() as counter:
|
with RequestCounter() as counter:
|
||||||
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter["GET"], 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter["HEAD"], 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.total_calls, 1)
|
||||||
|
|
||||||
def test_attr_not_existing(self):
|
def test_attr_not_existing(self):
|
||||||
from transformers.models.auto.auto_factory import _LazyAutoMapping
|
from transformers.models.auto.auto_factory import _LazyAutoMapping
|
||||||
|
|||||||
@@ -301,14 +301,14 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
with RequestCounter() as counter:
|
with RequestCounter() as counter:
|
||||||
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter["GET"], 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter["HEAD"], 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.total_calls, 1)
|
||||||
|
|
||||||
# With a sharded checkpoint
|
# With a sharded checkpoint
|
||||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
with RequestCounter() as counter:
|
with RequestCounter() as counter:
|
||||||
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter["GET"], 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter["HEAD"], 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.total_calls, 1)
|
||||||
|
|||||||
@@ -419,14 +419,11 @@ class AutoTokenizerTest(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
_ = AutoTokenizer.from_pretrained(DUMMY_UNKNOWN_IDENTIFIER, revision="aaaaaa")
|
||||||
|
|
||||||
@unittest.skip(
|
|
||||||
"Currently failing with new huggingface_hub release. See: https://github.com/huggingface/transformers/pull/27389"
|
|
||||||
)
|
|
||||||
def test_cached_tokenizer_has_minimum_calls_to_head(self):
|
def test_cached_tokenizer_has_minimum_calls_to_head(self):
|
||||||
# Make sure we have cached the tokenizer.
|
# Make sure we have cached the tokenizer.
|
||||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
with RequestCounter() as counter:
|
with RequestCounter() as counter:
|
||||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter["GET"], 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter["HEAD"], 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.total_calls, 1)
|
||||||
|
|||||||
@@ -763,9 +763,9 @@ class CustomPipelineTest(unittest.TestCase):
|
|||||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||||
with RequestCounter() as counter:
|
with RequestCounter() as counter:
|
||||||
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
|
||||||
self.assertEqual(counter.get_request_count, 0)
|
self.assertEqual(counter["GET"], 0)
|
||||||
self.assertEqual(counter.head_request_count, 1)
|
self.assertEqual(counter["HEAD"], 1)
|
||||||
self.assertEqual(counter.other_request_count, 0)
|
self.assertEqual(counter.total_calls, 1)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_chunk_pipeline_batching_single_file(self):
|
def test_chunk_pipeline_batching_single_file(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user