Fix more CI tests (#35661)

Mark the heaviest tests with the `tooslow` decorator.
This commit is contained in:
Arthur
2025-01-23 14:45:42 +01:00
committed by GitHub
parent 0a950e0bbe
commit 8f1509a96c

View File

@@ -28,6 +28,7 @@ from transformers.testing_utils import (
require_torch,
require_torch_gpu,
slow,
tooslow,
torch_device,
)
@@ -209,6 +210,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
@tooslow
@require_read_token
def test_model_9b_bf16(self):
model_id = "google/gemma-2-9b"
@@ -229,6 +231,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
@tooslow
@require_read_token
def test_model_9b_fp16(self):
model_id = "google/gemma-2-9b"
@@ -250,6 +253,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_read_token
@tooslow
def test_model_9b_pipeline_bf16(self):
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
model_id = "google/gemma-2-9b"
@@ -296,6 +300,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
@require_torch_gpu
@mark.flash_attn_test
@slow
@tooslow
def test_model_9b_flash_attn(self):
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
model_id = "google/gemma-2-9b"
@@ -370,6 +375,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
@require_read_token
@tooslow
def test_model_9b_bf16_flex_attention(self):
model_id = "google/gemma-2-9b"
EXPECTED_TEXTS = [