update require_read_token (#38093)
* update require_read_token * new repo * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -31,6 +31,7 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import types
|
||||||
import unittest
|
import unittest
|
||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Generator, Iterable, Iterator, Mapping
|
from collections.abc import Generator, Iterable, Iterator, Mapping
|
||||||
@@ -560,21 +561,39 @@ def require_torch_sdpa(test_case):
|
|||||||
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
|
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_read_token(fn):
|
def require_read_token(test_case):
|
||||||
"""
|
"""
|
||||||
A decorator that loads the HF token for tests that require to load gated models.
|
A decorator that loads the HF token for tests that require to load gated models.
|
||||||
"""
|
"""
|
||||||
token = os.getenv("HF_HUB_READ_TOKEN")
|
token = os.getenv("HF_HUB_READ_TOKEN")
|
||||||
|
|
||||||
@wraps(fn)
|
if isinstance(test_case, type):
|
||||||
def _inner(*args, **kwargs):
|
for attr_name in dir(test_case):
|
||||||
|
attr = getattr(test_case, attr_name)
|
||||||
|
if isinstance(attr, types.FunctionType):
|
||||||
|
if getattr(attr, "__require_read_token__", False):
|
||||||
|
continue
|
||||||
|
wrapped = require_read_token(attr)
|
||||||
|
setattr(test_case, attr_name, wrapped)
|
||||||
|
return test_case
|
||||||
|
else:
|
||||||
|
if getattr(test_case, "__require_read_token__", False):
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
@functools.wraps(test_case)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
if token is not None:
|
if token is not None:
|
||||||
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
|
with patch("huggingface_hub.utils._headers.get_token", return_value=token):
|
||||||
return fn(*args, **kwargs)
|
return test_case(*args, **kwargs)
|
||||||
else: # Allow running locally with the default token env variable
|
else: # Allow running locally with the default token env variable
|
||||||
return fn(*args, **kwargs)
|
# dealing with static/class methods and called by `self.xxx`
|
||||||
|
if "staticmethod" in inspect.getsource(test_case).strip():
|
||||||
|
if len(args) > 0 and isinstance(args[0], unittest.TestCase):
|
||||||
|
return test_case(*args[1:], **kwargs)
|
||||||
|
return test_case(*args, **kwargs)
|
||||||
|
|
||||||
return _inner
|
wrapper.__require_read_token__ = True
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_peft(test_case):
|
def require_peft(test_case):
|
||||||
|
|||||||
@@ -51,10 +51,12 @@ class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
|||||||
image_std=[0.229, 0.224, 0.225],
|
image_std=[0.229, 0.224, 0.225],
|
||||||
do_convert_rgb=True,
|
do_convert_rgb=True,
|
||||||
)
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b", padding_side="left")
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b", padding_side="left"
|
||||||
|
)
|
||||||
processor_kwargs = cls.prepare_processor_dict()
|
processor_kwargs = cls.prepare_processor_dict()
|
||||||
processor = AyaVisionProcessor.from_pretrained(
|
processor = AyaVisionProcessor.from_pretrained(
|
||||||
"CohereForAI/aya-vision-8b",
|
"hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b",
|
||||||
image_processor=image_processor,
|
image_processor=image_processor,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
**processor_kwargs,
|
**processor_kwargs,
|
||||||
|
|||||||
Reference in New Issue
Block a user