update require_read_token (#38093)

* update require_read_token

* new repo

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2025-05-13 12:07:07 +02:00
committed by GitHub
parent e3b70b0d1c
commit 3ad35d0bca
2 changed files with 32 additions and 11 deletions

View File

@@ -31,6 +31,7 @@ import sys
import tempfile
import threading
import time
import types
import unittest
from collections import UserDict, defaultdict
from collections.abc import Generator, Iterable, Iterator, Mapping
@@ -560,21 +561,39 @@ def require_torch_sdpa(test_case):
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
def require_read_token(fn):
def require_read_token(test_case):
"""
A decorator that loads the HF token for tests that require to load gated models.
"""
token = os.getenv("HF_HUB_READ_TOKEN")
@wraps(fn)
def _inner(*args, **kwargs):
if token is not None:
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
return fn(*args, **kwargs)
else: # Allow running locally with the default token env variable
return fn(*args, **kwargs)
if isinstance(test_case, type):
for attr_name in dir(test_case):
attr = getattr(test_case, attr_name)
if isinstance(attr, types.FunctionType):
if getattr(attr, "__require_read_token__", False):
continue
wrapped = require_read_token(attr)
setattr(test_case, attr_name, wrapped)
return test_case
else:
if getattr(test_case, "__require_read_token__", False):
return test_case
return _inner
@functools.wraps(test_case)
def wrapper(*args, **kwargs):
if token is not None:
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
return test_case(*args, **kwargs)
else: # Allow running locally with the default token env variable
# dealing with static/class methods and called by `self.xxx`
if "staticmethod" in inspect.getsource(test_case).strip():
if len(args) > 0 and isinstance(args[0], unittest.TestCase):
return test_case(*args[1:], **kwargs)
return test_case(*args, **kwargs)
wrapper.__require_read_token__ = True
return wrapper
def require_peft(test_case):

View File

@@ -51,10 +51,12 @@ class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
image_std=[0.229, 0.224, 0.225],
do_convert_rgb=True,
)
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b", padding_side="left")
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b", padding_side="left"
)
processor_kwargs = cls.prepare_processor_dict()
processor = AyaVisionProcessor.from_pretrained(
"CohereForAI/aya-vision-8b",
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b",
image_processor=image_processor,
tokenizer=tokenizer,
**processor_kwargs,