update require_read_token (#38093)

* update require_read_token

* new repo

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-05-13 12:07:07 +02:00
committed by GitHub
parent e3b70b0d1c
commit 3ad35d0bca
2 changed files with 32 additions and 11 deletions

View File

@@ -31,6 +31,7 @@ import sys
import tempfile import tempfile
import threading import threading
import time import time
import types
import unittest import unittest
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Generator, Iterable, Iterator, Mapping from collections.abc import Generator, Iterable, Iterator, Mapping
@@ -560,21 +561,39 @@ def require_torch_sdpa(test_case):
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
def require_read_token(fn): def require_read_token(test_case):
""" """
A decorator that loads the HF token for tests that require to load gated models. A decorator that loads the HF token for tests that require to load gated models.
""" """
token = os.getenv("HF_HUB_READ_TOKEN") token = os.getenv("HF_HUB_READ_TOKEN")
@wraps(fn) if isinstance(test_case, type):
def _inner(*args, **kwargs): for attr_name in dir(test_case):
if token is not None: attr = getattr(test_case, attr_name)
with patch("huggingface_hub.utils._headers.get_token", return_value=token): if isinstance(attr, types.FunctionType):
return fn(*args, **kwargs) if getattr(attr, "__require_read_token__", False):
else: # Allow running locally with the default token env variable continue
return fn(*args, **kwargs) wrapped = require_read_token(attr)
setattr(test_case, attr_name, wrapped)
return test_case
else:
if getattr(test_case, "__require_read_token__", False):
return test_case
return _inner @functools.wraps(test_case)
def wrapper(*args, **kwargs):
if token is not None:
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
return test_case(*args, **kwargs)
else: # Allow running locally with the default token env variable
# dealing with static/class methods and called by `self.xxx`
if "staticmethod" in inspect.getsource(test_case).strip():
if len(args) > 0 and isinstance(args[0], unittest.TestCase):
return test_case(*args[1:], **kwargs)
return test_case(*args, **kwargs)
wrapper.__require_read_token__ = True
return wrapper
def require_peft(test_case): def require_peft(test_case):

View File

@@ -51,10 +51,12 @@ class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_std=[0.229, 0.224, 0.225], image_std=[0.229, 0.224, 0.225],
do_convert_rgb=True, do_convert_rgb=True,
) )
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b", padding_side="left") tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b", padding_side="left"
)
processor_kwargs = cls.prepare_processor_dict() processor_kwargs = cls.prepare_processor_dict()
processor = AyaVisionProcessor.from_pretrained( processor = AyaVisionProcessor.from_pretrained(
"CohereForAI/aya-vision-8b", "hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b",
image_processor=image_processor, image_processor=image_processor,
tokenizer=tokenizer, tokenizer=tokenizer,
**processor_kwargs, **processor_kwargs,