Fix @require_read_token in tests (#29367)
This commit is contained in:
@@ -38,7 +38,6 @@ from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
|||||||
from unittest import mock
|
from unittest import mock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import huggingface_hub
|
|
||||||
import urllib3
|
import urllib3
|
||||||
|
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
@@ -466,11 +465,11 @@ def require_read_token(fn):
|
|||||||
"""
|
"""
|
||||||
A decorator that loads the HF token for tests that require to load gated models.
|
A decorator that loads the HF token for tests that require to load gated models.
|
||||||
"""
|
"""
|
||||||
token = os.getenv("HF_HUB_READ_TOKEN", None)
|
token = os.getenv("HF_HUB_READ_TOKEN")
|
||||||
|
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def _inner(*args, **kwargs):
|
def _inner(*args, **kwargs):
|
||||||
with patch.object(huggingface_hub.utils._headers, "get_token", return_value=token):
|
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
return _inner
|
return _inner
|
||||||
|
|||||||
Reference in New Issue
Block a user