From b3db4ddb2255bb4c8c4340fa630a53ac1cc53dee Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 9 May 2025 03:48:09 +0800 Subject: [PATCH] enable mamba2 integration cases on xpu (#38006) * enable mamba2 integration cases on XPU Signed-off-by: Yao Matrix * fix style Signed-off-by: Yao Matrix --------- Signed-off-by: Yao Matrix --- tests/models/mamba2/test_modeling_mamba2.py | 29 +++++++++++++++------ 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index 31565bf23d..6d9b98cced 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -18,7 +18,14 @@ import unittest from parameterized import parameterized from transformers import AutoTokenizer, Mamba2Config, is_torch_available -from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device +from transformers.testing_utils import ( + Expectations, + require_read_token, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from ...generation.test_utils import GenerationTesterMixin @@ -357,12 +364,18 @@ class Mamba2IntegrationTest(unittest.TestCase): out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30) output_sentence = tokenizer.decode(out[0]) - ground_truth_sentence = """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include \n\n""" + ground_truth_sentences = Expectations( + { + ("xpu", 3): """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program written in C++:\n\n```cpp\n#include \n""", + ("cuda", 7): """[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include \n\n""", + } + ) # fmt: skip + ground_truth_sentence = ground_truth_sentences.get_expectation() self.assertEqual(output_sentence, ground_truth_sentence) @require_read_token @slow - @require_torch_gpu + @require_torch_accelerator def test_batched_equivalence_with_cache(self): """ Verifies that batched generation matches individual generation. @@ -393,7 +406,7 @@ class Mamba2IntegrationTest(unittest.TestCase): @require_read_token @slow - @require_torch_gpu + @require_torch_accelerator def test_batched_equivalence_without_cache(self): """ Verifies that batched generation matches individual generation without cache. @@ -423,7 +436,7 @@ class Mamba2IntegrationTest(unittest.TestCase): self.assertEqual(individual_output[:100], batched_output[index_gen][:100]) @slow - @require_torch_gpu + @require_torch_accelerator def test_mamba2_mixer_train_vs_eval_equivalence(self): # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63 # Credit to zhixuan-lin @@ -433,10 +446,10 @@ class Mamba2IntegrationTest(unittest.TestCase): config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1) torch.manual_seed(42) - with torch.amp.autocast(device_type="cuda", dtype=dtype): + with torch.amp.autocast(device_type=torch_device, dtype=dtype): with torch.no_grad(): - mixer = Mamba2Mixer(config, layer_idx=0).to("cuda") - hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda") + mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device) + hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device) mixer.train() out_train = mixer(hidden_states)