Fix RequestCounter to make it more future-proof (#27406)
* Fix RequestCounter to make it more future-proof * code quality
This commit is contained in:
@@ -29,14 +29,15 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import huggingface_hub
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
|
||||
@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
|
||||
class RequestCounter:
|
||||
"""
|
||||
Helper class that will count all requests made online.
|
||||
|
||||
Might not be robust if urllib3 changes its logging format but should be good enough for us.
|
||||
|
||||
Usage:
|
||||
```py
|
||||
with RequestCounter() as counter:
|
||||
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
assert counter["GET"] == 0
|
||||
assert counter["HEAD"] == 1
|
||||
assert counter.total_calls == 1
|
||||
```
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
self.head_request_count = 0
|
||||
self.get_request_count = 0
|
||||
self.other_request_count = 0
|
||||
|
||||
# Mock `get_session` to count HTTP calls.
|
||||
self.old_get_session = huggingface_hub.utils._http.get_session
|
||||
self.session = requests.Session()
|
||||
self.session.request = self.new_request
|
||||
huggingface_hub.utils._http.get_session = lambda: self.session
|
||||
self._counter = defaultdict(int)
|
||||
self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
|
||||
self.mock = self.patcher.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args, **kwargs):
|
||||
huggingface_hub.utils._http.get_session = self.old_get_session
|
||||
def __exit__(self, *args, **kwargs) -> None:
|
||||
for call in self.mock.call_args_list:
|
||||
log = call.args[0] % call.args[1:]
|
||||
for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
|
||||
if method in log:
|
||||
self._counter[method] += 1
|
||||
break
|
||||
self.patcher.stop()
|
||||
|
||||
def new_request(self, method, **kwargs):
|
||||
if method == "GET":
|
||||
self.get_request_count += 1
|
||||
elif method == "HEAD":
|
||||
self.head_request_count += 1
|
||||
else:
|
||||
self.other_request_count += 1
|
||||
def __getitem__(self, key: str) -> int:
|
||||
return self._counter[key]
|
||||
|
||||
return requests.request(method=method, **kwargs)
|
||||
@property
|
||||
def total_calls(self) -> int:
|
||||
return sum(self._counter.values())
|
||||
|
||||
|
||||
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
|
||||
|
||||
Reference in New Issue
Block a user