enable 6 modeling cases on XPU (#37571)
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -19,7 +19,14 @@ import unittest
|
||||
import pytest
|
||||
|
||||
from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
||||
from transformers.testing_utils import Expectations, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_deterministic_for_xpu,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -474,7 +481,7 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
class BambaModelIntegrationTest(unittest.TestCase):
|
||||
model = None
|
||||
tokenizer = None
|
||||
@@ -507,6 +514,10 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
"rocm",
|
||||
9,
|
||||
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
(
|
||||
"xpu",
|
||||
3,
|
||||
): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all doing well. Today I",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -536,22 +547,30 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1)
|
||||
|
||||
@require_deterministic_for_xpu
|
||||
def test_simple_batched_generate_with_padding(self):
|
||||
# Key 9 for MI300, Key 8 for A100/A10, and Key 7 for T4.
|
||||
#
|
||||
# Note: Key 9 is currently set for MI300, but may need potential future adjustments for H100s,
|
||||
# considering differences in hardware processing and potential deviations in generated text.
|
||||
EXPECTED_TEXTS = {
|
||||
7: [],
|
||||
8: [
|
||||
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
"!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
|
||||
],
|
||||
9: [
|
||||
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
"!!!<|begin_of_text|>I am late! I need to be at the airport in 20 minutes! I",
|
||||
],
|
||||
}
|
||||
EXPECTED_TEXTS = Expectations(
|
||||
{
|
||||
("cuda", 7): [],
|
||||
("cuda", 8): [
|
||||
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
"!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
|
||||
],
|
||||
("rocm", 9): [
|
||||
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
|
||||
"!!!<|begin_of_text|>I am late! I need to be at the airport in 20 minutes! I",
|
||||
],
|
||||
("xpu", 3): [
|
||||
"<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all doing well. Today I",
|
||||
"!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
|
||||
],
|
||||
}
|
||||
)
|
||||
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
||||
|
||||
self.model.to(torch_device)
|
||||
|
||||
@@ -562,8 +581,8 @@ class BambaModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
output_sentences = self.tokenizer.batch_decode(out)
|
||||
self.assertEqual(output_sentences[0], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][0])
|
||||
self.assertEqual(output_sentences[1], EXPECTED_TEXTS[self.cuda_compute_capability_major_version][1])
|
||||
self.assertEqual(output_sentences[0], EXPECTED_TEXT[0])
|
||||
self.assertEqual(output_sentences[1], EXPECTED_TEXT[1])
|
||||
|
||||
# TODO: there are significant differences in the logits across major cuda versions, which shouldn't exist
|
||||
if self.cuda_compute_capability_major_version == 8:
|
||||
|
||||
Reference in New Issue
Block a user