From ad00c482c7fe9437c93bbc6be5a4a428c3219b5c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 28 Feb 2024 06:25:23 +0100 Subject: [PATCH] FIX [`Gemma` / `CI`] Make sure our runners have access to the model (#29242) * pu hf token in gemma tests * update suggestion * add to flax * revert * fix * fixup * forward contrib credits from discussion --------- Co-authored-by: ArthurZucker --- src/transformers/testing_utils.py | 16 ++++++++++++++++ tests/models/gemma/test_modeling_flax_gemma.py | 5 ++--- tests/models/gemma/test_modeling_gemma.py | 3 ++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index ca4b0db8b8..e1415a4cc6 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -31,12 +31,14 @@ import time import unittest from collections import defaultdict from collections.abc import Mapping +from functools import wraps from io import StringIO from pathlib import Path from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union from unittest import mock from unittest.mock import patch +import huggingface_hub import urllib3 from transformers import logging as transformers_logging @@ -460,6 +462,20 @@ def require_torch_sdpa(test_case): return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) +def require_read_token(fn): + """ + A decorator that loads the HF token for tests that require to load gated models. + """ + token = os.getenv("HF_HUB_READ_TOKEN", None) + + @wraps(fn) + def _inner(*args, **kwargs): + with patch(huggingface_hub.utils._headers, "get_token", return_value=token): + return fn(*args, **kwargs) + + return _inner + + def require_peft(test_case): """ Decorator marking a test that requires PEFT. diff --git a/tests/models/gemma/test_modeling_flax_gemma.py b/tests/models/gemma/test_modeling_flax_gemma.py index 515ec1837d..0f3c5df4f1 100644 --- a/tests/models/gemma/test_modeling_flax_gemma.py +++ b/tests/models/gemma/test_modeling_flax_gemma.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import unittest import numpy as np from transformers import AutoTokenizer, GemmaConfig, is_flax_available -from transformers.testing_utils import require_flax, slow +from transformers.testing_utils import require_flax, require_read_token, slow from ...generation.test_flax_utils import FlaxGenerationTesterMixin from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor @@ -205,6 +203,7 @@ class FlaxGemmaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unitte @slow @require_flax +@require_read_token class FlaxGemmaIntegrationTest(unittest.TestCase): input_text = ["The capital of France is", "To play the perfect cover drive"] model_id = "google/gemma-2b" diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index 670519d2a1..6385e4cbf5 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Testing suite for the PyTorch Gemma model. """ - import tempfile import unittest @@ -24,6 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_to from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_read_token, require_torch, require_torch_gpu, require_torch_sdpa, @@ -529,6 +529,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi @require_torch_gpu @slow +@require_read_token class GemmaIntegrationTest(unittest.TestCase): input_text = ["Hello I am doing", "Hi today"]