update require_read_token (#38093)
* update require_read_token * new repo * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -31,6 +31,7 @@ import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Generator, Iterable, Iterator, Mapping
|
||||
@@ -560,21 +561,39 @@ def require_torch_sdpa(test_case):
|
||||
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
|
||||
|
||||
|
||||
def require_read_token(fn):
|
||||
def require_read_token(test_case):
|
||||
"""
|
||||
A decorator that loads the HF token for tests that require to load gated models.
|
||||
"""
|
||||
token = os.getenv("HF_HUB_READ_TOKEN")
|
||||
|
||||
@wraps(fn)
|
||||
def _inner(*args, **kwargs):
|
||||
if token is not None:
|
||||
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
|
||||
return fn(*args, **kwargs)
|
||||
else: # Allow running locally with the default token env variable
|
||||
return fn(*args, **kwargs)
|
||||
if isinstance(test_case, type):
|
||||
for attr_name in dir(test_case):
|
||||
attr = getattr(test_case, attr_name)
|
||||
if isinstance(attr, types.FunctionType):
|
||||
if getattr(attr, "__require_read_token__", False):
|
||||
continue
|
||||
wrapped = require_read_token(attr)
|
||||
setattr(test_case, attr_name, wrapped)
|
||||
return test_case
|
||||
else:
|
||||
if getattr(test_case, "__require_read_token__", False):
|
||||
return test_case
|
||||
|
||||
return _inner
|
||||
@functools.wraps(test_case)
|
||||
def wrapper(*args, **kwargs):
|
||||
if token is not None:
|
||||
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
|
||||
return test_case(*args, **kwargs)
|
||||
else: # Allow running locally with the default token env variable
|
||||
# dealing with static/class methods and called by `self.xxx`
|
||||
if "staticmethod" in inspect.getsource(test_case).strip():
|
||||
if len(args) > 0 and isinstance(args[0], unittest.TestCase):
|
||||
return test_case(*args[1:], **kwargs)
|
||||
return test_case(*args, **kwargs)
|
||||
|
||||
wrapper.__require_read_token__ = True
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
|
||||
@@ -51,10 +51,12 @@ class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
image_std=[0.229, 0.224, 0.225],
|
||||
do_convert_rgb=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b", padding_side="left")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b", padding_side="left"
|
||||
)
|
||||
processor_kwargs = cls.prepare_processor_dict()
|
||||
processor = AyaVisionProcessor.from_pretrained(
|
||||
"CohereForAI/aya-vision-8b",
|
||||
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b",
|
||||
image_processor=image_processor,
|
||||
tokenizer=tokenizer,
|
||||
**processor_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user