Expectation fixes and added AMD expectations (#38729)
This commit is contained in:
@@ -21,6 +21,7 @@ from packaging import version
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.testing_utils import (
|
||||
DeviceProperties,
|
||||
Expectations,
|
||||
cleanup,
|
||||
get_device_properties,
|
||||
@@ -108,7 +109,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
# This variable is used to determine which accelerator are we using for our runners (e.g. A10 or T4)
|
||||
# Depending on the hardware we get different logits / generations
|
||||
device_properties = None
|
||||
device_properties: DeviceProperties = (None, None, None)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -241,7 +242,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_fp16(self):
|
||||
if self.device_properties == ("cuda", 7):
|
||||
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
||||
|
||||
model_id = "google/gemma-7b"
|
||||
@@ -262,7 +263,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_bf16(self):
|
||||
if self.device_properties == ("cuda", 7):
|
||||
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
||||
|
||||
model_id = "google/gemma-7b"
|
||||
@@ -293,7 +294,7 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_read_token
|
||||
def test_model_7b_fp16_static_cache(self):
|
||||
if self.device_properties == ("cuda", 7):
|
||||
if self.device_properties[0] == "cuda" and self.device_properties[1] == 7:
|
||||
self.skipTest("This test is failing (`torch.compile` fails) on Nvidia T4 GPU (OOM).")
|
||||
|
||||
model_id = "google/gemma-7b"
|
||||
|
||||
Reference in New Issue
Block a user