@@ -28,6 +28,7 @@ from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
tooslow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -209,6 +210,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
# 8 is for A100 / A10 and 7 for T4
|
||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||
|
||||
@tooslow
|
||||
@require_read_token
|
||||
def test_model_9b_bf16(self):
|
||||
model_id = "google/gemma-2-9b"
|
||||
@@ -229,6 +231,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@tooslow
|
||||
@require_read_token
|
||||
def test_model_9b_fp16(self):
|
||||
model_id = "google/gemma-2-9b"
|
||||
@@ -250,6 +253,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_read_token
|
||||
@tooslow
|
||||
def test_model_9b_pipeline_bf16(self):
|
||||
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
|
||||
model_id = "google/gemma-2-9b"
|
||||
@@ -296,6 +300,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
@slow
|
||||
@tooslow
|
||||
def test_model_9b_flash_attn(self):
|
||||
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
|
||||
model_id = "google/gemma-2-9b"
|
||||
@@ -370,6 +375,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||
|
||||
@require_read_token
|
||||
@tooslow
|
||||
def test_model_9b_bf16_flex_attention(self):
|
||||
model_id = "google/gemma-2-9b"
|
||||
EXPECTED_TEXTS = [
|
||||
|
||||
Reference in New Issue
Block a user