enable mamba2 integration cases on xpu (#38006)

* enable mamba2 integration cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-05-09 03:48:09 +08:00
committed by GitHub
parent c7c2f08994
commit b3db4ddb22

View File

@@ -18,7 +18,14 @@ import unittest
from parameterized import parameterized
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
from transformers.testing_utils import (
Expectations,
require_read_token,
require_torch,
require_torch_accelerator,
slow,
torch_device,
)
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from ...generation.test_utils import GenerationTesterMixin
@@ -357,12 +364,18 @@ class Mamba2IntegrationTest(unittest.TestCase):
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
output_sentence = tokenizer.decode(out[0])
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
ground_truth_sentences = Expectations(
{
("xpu", 3): """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program written in C++:\n\n```cpp\n#include <iostream>\n""",
("cuda", 7): """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n""",
}
) # fmt: skip
ground_truth_sentence = ground_truth_sentences.get_expectation()
self.assertEqual(output_sentence, ground_truth_sentence)
@require_read_token
@slow
@require_torch_gpu
@require_torch_accelerator
def test_batched_equivalence_with_cache(self):
"""
Verifies that batched generation matches individual generation.
@@ -393,7 +406,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
@require_read_token
@slow
@require_torch_gpu
@require_torch_accelerator
def test_batched_equivalence_without_cache(self):
"""
Verifies that batched generation matches individual generation without cache.
@@ -423,7 +436,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
@slow
@require_torch_gpu
@require_torch_accelerator
def test_mamba2_mixer_train_vs_eval_equivalence(self):
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
# Credit to zhixuan-lin
@@ -433,10 +446,10 @@ class Mamba2IntegrationTest(unittest.TestCase):
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
torch.manual_seed(42)
with torch.amp.autocast(device_type="cuda", dtype=dtype):
with torch.amp.autocast(device_type=torch_device, dtype=dtype):
with torch.no_grad():
mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
mixer.train()
out_train = mixer(hidden_states)