enable mamba2 integration cases on xpu (#38006)
* enable mamba2 integration cases on XPU Signed-off-by: Yao Matrix <matrix.yao@intel.com> * fix style Signed-off-by: Yao Matrix <matrix.yao@intel.com> --------- Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -18,7 +18,14 @@ import unittest
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_read_token,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
@@ -357,12 +364,18 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
|
||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||
output_sentence = tokenizer.decode(out[0])
|
||||
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
|
||||
ground_truth_sentences = Expectations(
|
||||
{
|
||||
("xpu", 3): """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program written in C++:\n\n```cpp\n#include <iostream>\n""",
|
||||
("cuda", 7): """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n""",
|
||||
}
|
||||
) # fmt: skip
|
||||
ground_truth_sentence = ground_truth_sentences.get_expectation()
|
||||
self.assertEqual(output_sentence, ground_truth_sentence)
|
||||
|
||||
@require_read_token
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_batched_equivalence_with_cache(self):
|
||||
"""
|
||||
Verifies that batched generation matches individual generation.
|
||||
@@ -393,7 +406,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_read_token
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_batched_equivalence_without_cache(self):
|
||||
"""
|
||||
Verifies that batched generation matches individual generation without cache.
|
||||
@@ -423,7 +436,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
def test_mamba2_mixer_train_vs_eval_equivalence(self):
|
||||
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
|
||||
# Credit to zhixuan-lin
|
||||
@@ -433,10 +446,10 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
|
||||
|
||||
torch.manual_seed(42)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
||||
with torch.amp.autocast(device_type=torch_device, dtype=dtype):
|
||||
with torch.no_grad():
|
||||
mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
|
||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
||||
mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
|
||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
|
||||
|
||||
mixer.train()
|
||||
out_train = mixer(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user