FIX [Gemma / CI] Make sure our runners have access to the model (#29242)
* pu hf token in gemma tests * update suggestion * add to flax * revert * fix * fixup * forward contrib credits from discussion --------- Co-authored-by: ArthurZucker <ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -31,12 +31,14 @@ import time
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from functools import wraps
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import huggingface_hub
|
||||
import urllib3
|
||||
|
||||
from transformers import logging as transformers_logging
|
||||
@@ -460,6 +462,20 @@ def require_torch_sdpa(test_case):
|
||||
return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
|
||||
|
||||
|
||||
def require_read_token(fn):
|
||||
"""
|
||||
A decorator that loads the HF token for tests that require to load gated models.
|
||||
"""
|
||||
token = os.getenv("HF_HUB_READ_TOKEN", None)
|
||||
|
||||
@wraps(fn)
|
||||
def _inner(*args, **kwargs):
|
||||
with patch(huggingface_hub.utils._headers, "get_token", return_value=token):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _inner
|
||||
|
||||
|
||||
def require_peft(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT.
|
||||
|
||||
@@ -11,14 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import AutoTokenizer, GemmaConfig, is_flax_available
|
||||
from transformers.testing_utils import require_flax, slow
|
||||
from transformers.testing_utils import require_flax, require_read_token, slow
|
||||
|
||||
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
|
||||
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
|
||||
@@ -205,6 +203,7 @@ class FlaxGemmaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unitte
|
||||
|
||||
@slow
|
||||
@require_flax
|
||||
@require_read_token
|
||||
class FlaxGemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["The capital of France is", "To play the perfect cover drive"]
|
||||
model_id = "google/gemma-2b"
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Gemma model. """
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -24,6 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_to
|
||||
from transformers.testing_utils import (
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_sdpa,
|
||||
@@ -529,6 +529,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
@require_read_token
|
||||
class GemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user