Fix RequestCounter to make it more future-proof (#27406)

* Fix RequestCounter to make it more future-proof

* code quality
This commit is contained in:
Lucain
2023-11-09 18:53:26 +01:00
committed by GitHub
parent c8b6052ff6
commit e38348ae8f
5 changed files with 48 additions and 45 deletions

View File

@@ -29,14 +29,15 @@ import sys
import tempfile
import time
import unittest
from collections import defaultdict
from collections.abc import Mapping
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from unittest import mock
from unittest.mock import patch
import huggingface_hub
import requests
import urllib3
from transformers import logging as transformers_logging
@@ -1983,32 +1984,40 @@ def run_command(command: List[str], return_stdout=False):
class RequestCounter:
    """
    Helper context manager that counts the HTTP requests made while it is active.

    Requests are counted by wrapping `urllib3.connectionpool.log.debug` and inspecting the messages it receives once
    the context exits. Might not be robust if urllib3 changes its logging format but should be good enough for us.

    Counts are computed in `__exit__`, so they must be read *after* the `with` block.

    Usage:
    ```py
    with RequestCounter() as counter:
        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
    assert counter["GET"] == 0
    assert counter["HEAD"] == 1
    assert counter.total_calls == 1
    ```
    """

    # HTTP verbs recognised in the log lines, in the order they are tried by the substring fallback.
    _HTTP_METHODS = ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH")

    def __enter__(self):
        # Start from a fresh counter every time the context is entered.
        self._counter = defaultdict(int)
        # `wraps=` keeps the real `debug` behaviour while recording every call made to it.
        self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
        self.mock = self.patcher.start()
        return self

    def __exit__(self, *args, **kwargs) -> None:
        try:
            for call in self.mock.call_args_list:
                if not call.args:
                    continue
                message, fmt_args = str(call.args[0]), call.args[1:]
                # Mirror `logging`: only %-format when arguments were supplied, so a literal "%" in an
                # argument-less message does not raise.
                log = message % fmt_args if fmt_args else message
                method = self._method_from_log(log)
                if method is not None:
                    self._counter[method] += 1
        finally:
            # Always restore the real logger method, even if parsing a log line failed.
            self.patcher.stop()

    @staticmethod
    def _method_from_log(log: str) -> Optional[str]:
        """Return the HTTP method mentioned in a log line, or `None` if there is none."""
        # Prefer the token right after the first double quote (the start of a quoted request line such as
        # `"HEAD /path HTTP/1.1"`), so a verb appearing inside the URL or host is not mistaken for the method.
        _, quote, request_line = log.partition('"')
        if quote:
            candidate = request_line.split(" ", 1)[0]
            if candidate in RequestCounter._HTTP_METHODS:
                return candidate
        # Fallback: first known verb found anywhere in the line.
        for method in RequestCounter._HTTP_METHODS:
            if method in log:
                return method
        return None

    def __getitem__(self, key: str) -> int:
        # `.get` avoids inserting new keys into the defaultdict on read.
        return self._counter.get(key, 0)

    @property
    def total_calls(self) -> int:
        """Total number of counted requests, all methods included."""
        return sum(self._counter.values())

    # Read-only aliases kept for callers still using the previous attribute names.
    @property
    def head_request_count(self) -> int:
        return self["HEAD"]

    @property
    def get_request_count(self) -> int:
        return self["GET"]

    @property
    def other_request_count(self) -> int:
        return self.total_calls - self["HEAD"] - self["GET"]
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):