enable mamba2 integration cases on xpu (#38006)
* enable mamba2 integration cases on XPU Signed-off-by: Yao Matrix <matrix.yao@intel.com> * fix style Signed-off-by: Yao Matrix <matrix.yao@intel.com> --------- Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -18,7 +18,14 @@ import unittest
|
|||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||||
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
|
from transformers.testing_utils import (
|
||||||
|
Expectations,
|
||||||
|
require_read_token,
|
||||||
|
require_torch,
|
||||||
|
require_torch_accelerator,
|
||||||
|
slow,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
@@ -357,12 +364,18 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
|
||||||
output_sentence = tokenizer.decode(out[0])
|
output_sentence = tokenizer.decode(out[0])
|
||||||
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
|
ground_truth_sentences = Expectations(
|
||||||
|
{
|
||||||
|
("xpu", 3): """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program written in C++:\n\n```cpp\n#include <iostream>\n""",
|
||||||
|
("cuda", 7): """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n""",
|
||||||
|
}
|
||||||
|
) # fmt: skip
|
||||||
|
ground_truth_sentence = ground_truth_sentences.get_expectation()
|
||||||
self.assertEqual(output_sentence, ground_truth_sentence)
|
self.assertEqual(output_sentence, ground_truth_sentence)
|
||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_batched_equivalence_with_cache(self):
|
def test_batched_equivalence_with_cache(self):
|
||||||
"""
|
"""
|
||||||
Verifies that batched generation matches individual generation.
|
Verifies that batched generation matches individual generation.
|
||||||
@@ -393,7 +406,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_read_token
|
@require_read_token
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_batched_equivalence_without_cache(self):
|
def test_batched_equivalence_without_cache(self):
|
||||||
"""
|
"""
|
||||||
Verifies that batched generation matches individual generation without cache.
|
Verifies that batched generation matches individual generation without cache.
|
||||||
@@ -423,7 +436,7 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
def test_mamba2_mixer_train_vs_eval_equivalence(self):
|
def test_mamba2_mixer_train_vs_eval_equivalence(self):
|
||||||
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
|
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
|
||||||
# Credit to zhixuan-lin
|
# Credit to zhixuan-lin
|
||||||
@@ -433,10 +446,10 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
|
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
|
||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
with torch.amp.autocast(device_type=torch_device, dtype=dtype):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
|
mixer = Mamba2Mixer(config, layer_idx=0).to(torch_device)
|
||||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device=torch_device)
|
||||||
|
|
||||||
mixer.train()
|
mixer.train()
|
||||||
out_train = mixer(hidden_states)
|
out_train = mixer(hidden_states)
|
||||||
|
|||||||
Reference in New Issue
Block a user