From 3ad35d0bcaf5520f2b25441a50f205300a07f3d0 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 13 May 2025 12:07:07 +0200 Subject: [PATCH] update `require_read_token` (#38093) * update require_read_token * new repo * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/testing_utils.py | 37 ++++++++++++++----- .../aya_vision/test_processor_aya_vision.py | 6 ++- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 0a9bf1b06c..5ab348377d 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -31,6 +31,7 @@ import sys import tempfile import threading import time +import types import unittest from collections import UserDict, defaultdict from collections.abc import Generator, Iterable, Iterator, Mapping @@ -560,21 +561,39 @@ def require_torch_sdpa(test_case): return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) -def require_read_token(fn): +def require_read_token(test_case): """ A decorator that loads the HF token for tests that require to load gated models. """ token = os.getenv("HF_HUB_READ_TOKEN") - @wraps(fn) - def _inner(*args, **kwargs): - if token is not None: - with patch("huggingface_hub.utils._headers.get_token", return_value=token): - return fn(*args, **kwargs) - else: # Allow running locally with the default token env variable - return fn(*args, **kwargs) + if isinstance(test_case, type): + for attr_name in dir(test_case): + attr = getattr(test_case, attr_name) + if isinstance(attr, types.FunctionType): + if getattr(attr, "__require_read_token__", False): + continue + wrapped = require_read_token(attr) + setattr(test_case, attr_name, wrapped) + return test_case + else: + if getattr(test_case, "__require_read_token__", False): + return test_case - return _inner + @functools.wraps(test_case) + def wrapper(*args, **kwargs): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return test_case(*args, **kwargs) + else: # Allow running locally with the default token env variable + # dealing with static/class methods and called by `self.xxx` + if "staticmethod" in inspect.getsource(test_case).strip(): + if len(args) > 0 and isinstance(args[0], unittest.TestCase): + return test_case(*args[1:], **kwargs) + return test_case(*args, **kwargs) + + wrapper.__require_read_token__ = True + return wrapper def require_peft(test_case): diff --git a/tests/models/aya_vision/test_processor_aya_vision.py b/tests/models/aya_vision/test_processor_aya_vision.py index 9af13eab32..e0983d489e 100644 --- a/tests/models/aya_vision/test_processor_aya_vision.py +++ b/tests/models/aya_vision/test_processor_aya_vision.py @@ -51,10 +51,12 @@ class AyaVisionProcessorTest(ProcessorTesterMixin, unittest.TestCase): image_std=[0.229, 0.224, 0.225], do_convert_rgb=True, ) - tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-vision-8b", padding_side="left") + tokenizer = AutoTokenizer.from_pretrained( + "hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b", padding_side="left" + ) processor_kwargs = cls.prepare_processor_dict() processor = AyaVisionProcessor.from_pretrained( - "CohereForAI/aya-vision-8b", + "hf-internal-testing/namespace-CohereForAI-repo_name_aya-vision-8b", image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs,