[tests] enable GemmaIntegrationTest on XPU (#33555)

enable GemmaIntegrationTest
This commit is contained in:
Fanli Lin
2024-09-20 02:39:19 +08:00
committed by GitHub
parent b87755aa6d
commit 4d8908df27

View File

@@ -528,7 +528,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@slow
@require_torch_gpu
@require_torch_accelerator
class GemmaIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
@@ -748,7 +748,6 @@ class GemmaIntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
@require_read_token
@@ -770,10 +769,8 @@ class GemmaIntegrationTest(unittest.TestCase):
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(output_text, EXPECTED_TEXTS)
@require_bitsandbytes