From 8f1509a96c96747c893051ac947795cfb0750357 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:45:42 +0100 Subject: [PATCH] Fix more CI tests (#35661) add tooslow for the fat ones --- tests/models/gemma2/test_modeling_gemma2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 57c6331c8b..2bbe0d8e5b 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -28,6 +28,7 @@ from transformers.testing_utils import ( require_torch, require_torch_gpu, slow, + tooslow, torch_device, ) @@ -209,6 +210,7 @@ class Gemma2IntegrationTest(unittest.TestCase): # 8 is for A100 / A10 and 7 for T4 cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + @tooslow @require_read_token def test_model_9b_bf16(self): model_id = "google/gemma-2-9b" @@ -229,6 +231,7 @@ class Gemma2IntegrationTest(unittest.TestCase): self.assertEqual(output_text, EXPECTED_TEXTS) + @tooslow @require_read_token def test_model_9b_fp16(self): model_id = "google/gemma-2-9b" @@ -250,6 +253,7 @@ class Gemma2IntegrationTest(unittest.TestCase): self.assertEqual(output_text, EXPECTED_TEXTS) @require_read_token + @tooslow def test_model_9b_pipeline_bf16(self): # See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR model_id = "google/gemma-2-9b" @@ -296,6 +300,7 @@ class Gemma2IntegrationTest(unittest.TestCase): @require_torch_gpu @mark.flash_attn_test @slow + @tooslow def test_model_9b_flash_attn(self): # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context model_id = "google/gemma-2-9b" @@ -370,6 +375,7 @@ class Gemma2IntegrationTest(unittest.TestCase): self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) @require_read_token + @tooslow def test_model_9b_bf16_flex_attention(self): model_id = "google/gemma-2-9b" EXPECTED_TEXTS = [