[tests] enable GemmaIntegrationTest on XPU (#33555)
enable GemmaIntegrationTest
This commit is contained in:
@@ -528,7 +528,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class GemmaIntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
|
||||
@@ -748,7 +748,6 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS[self.cuda_compute_capability_major_version])
|
||||
|
||||
@require_read_token
|
||||
@@ -770,10 +769,8 @@ class GemmaIntegrationTest(unittest.TestCase):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
|
||||
|
||||
output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
|
||||
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
|
||||
|
||||
self.assertEqual(output_text, EXPECTED_TEXTS)
|
||||
|
||||
@require_bitsandbytes
|
||||
|
||||
Reference in New Issue
Block a user