@@ -28,6 +28,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch,
|
require_torch,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
slow,
|
slow,
|
||||||
|
tooslow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -209,6 +210,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
# 8 is for A100 / A10 and 7 for T4
|
# 8 is for A100 / A10 and 7 for T4
|
||||||
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
||||||
|
|
||||||
|
@tooslow
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_9b_bf16(self):
|
def test_model_9b_bf16(self):
|
||||||
model_id = "google/gemma-2-9b"
|
model_id = "google/gemma-2-9b"
|
||||||
@@ -229,6 +231,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
|
@tooslow
|
||||||
@require_read_token
|
@require_read_token
|
||||||
def test_model_9b_fp16(self):
|
def test_model_9b_fp16(self):
|
||||||
model_id = "google/gemma-2-9b"
|
model_id = "google/gemma-2-9b"
|
||||||
@@ -250,6 +253,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
|
@tooslow
|
||||||
def test_model_9b_pipeline_bf16(self):
|
def test_model_9b_pipeline_bf16(self):
|
||||||
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
|
# See https://github.com/huggingface/transformers/pull/31747 -- pipeline was broken for Gemma2 before this PR
|
||||||
model_id = "google/gemma-2-9b"
|
model_id = "google/gemma-2-9b"
|
||||||
@@ -296,6 +300,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
@slow
|
@slow
|
||||||
|
@tooslow
|
||||||
def test_model_9b_flash_attn(self):
|
def test_model_9b_flash_attn(self):
|
||||||
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
|
# See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
|
||||||
model_id = "google/gemma-2-9b"
|
model_id = "google/gemma-2-9b"
|
||||||
@@ -370,6 +375,7 @@ class Gemma2IntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)
|
||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
|
@tooslow
|
||||||
def test_model_9b_bf16_flex_attention(self):
|
def test_model_9b_bf16_flex_attention(self):
|
||||||
model_id = "google/gemma-2-9b"
|
model_id = "google/gemma-2-9b"
|
||||||
EXPECTED_TEXTS = [
|
EXPECTED_TEXTS = [
|
||||||
|
|||||||
Reference in New Issue
Block a user