[Mamba2] Fix caching, slow path, and multi-gpu (#35154)
* fixup mamba2 - caching and several other small fixes
* fixup cached forward
* correct fix this time
* fixup cache - we do not need to extend the attn mask it's handled by generate (gives total ids + mask at each step)
* remove unnecessary (un)squeeze
* fixup cache position
* simplify a few things
* [run-slow] mamba2
* multi gpu attempt two
* [run-slow] mamba2
* [run-slow] mamba2
* [run-slow] mamba2
* [run-slow] mamba2
* add newer slow path fix
* [run-slow] mamba2
This commit is contained in:
@@ -21,6 +21,7 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
|
||||
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
|
||||
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -103,6 +104,10 @@ class Mamba2ModelTester:
|
||||
):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
# Only left padding is valid
|
||||
attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
|
||||
attention_mask[0, :1] = 0
|
||||
|
||||
sequence_labels = None
|
||||
token_labels = None
|
||||
choice_labels = None
|
||||
@@ -118,7 +123,7 @@ class Mamba2ModelTester:
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
None,
|
||||
attention_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
@@ -158,6 +163,56 @@ class Mamba2ModelTester:
|
||||
inputs_dict = {"input_ids": input_ids}
|
||||
return config, inputs_dict
|
||||
|
||||
def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args):
    """Check that a cached two-call forward reproduces the single full forward.

    The model is run once over the complete ``input_ids``. It is then run
    again in two calls: first on everything except the last position with
    ``use_cache=True``, then on the last position alone, passing the
    ``cache_params`` returned by the first call. The two partial
    ``last_hidden_state`` tensors, concatenated along dim 1, must match the
    full-run ``last_hidden_state`` within ``atol=1e-3`` / ``rtol=1e-3``.

    Extra positional arguments are accepted and ignored.
    """
    device = input_ids.device
    kernel_size = config.conv_kernel

    model = Mamba2Model(config=config)
    model.to(torch_device)
    model.eval()

    # Reference: one pass over the whole sequence, no cache requested.
    reference = model(input_ids, attention_mask=attention_mask).last_hidden_state

    # First call: all positions but the last, asking the model to return a cache.
    prefill = model(
        input_ids[:, :-1],
        attention_mask=attention_mask[:, :-1],
        use_cache=True,
        cache_position=torch.arange(0, kernel_size, device=device),
    )

    # Second call: only the last position, continuing from the first call's cache.
    # Using the state computed on the first inputs, we will get the same output.
    decode = model(
        input_ids[:, -1:],
        attention_mask=attention_mask[:, -1:],
        use_cache=True,
        cache_params=prefill.cache_params,
        cache_position=torch.arange(kernel_size, kernel_size + 1, device=device),
    )

    stitched = torch.cat([prefill.last_hidden_state, decode.last_hidden_state], dim=1)
    self.parent.assertTrue(torch.allclose(stitched, reference, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False):
    """Check that the first layer's mixer gives matching outputs on both forward paths.

    The embeddings of ``input_ids`` are fed to ``cuda_kernels_forward`` and to
    ``torch_forward`` of ``model.layers[0].mixer``; the two results must agree
    within ``atol=1e-3`` / ``rtol=1e-3``.

    The test is skipped (via ``self.parent.skipTest``) when either
    ``is_mamba_2_ssm_available()`` or ``is_causal_conv1d_available()`` is
    false, or when ``torch_device`` is not ``"cuda"``.

    Args:
        config: configuration used to build the ``Mamba2Model``.
        input_ids: token ids passed to the model's embedding layer.
        *args: accepted and ignored.
        gradient_checkpointing: when true, ``gradient_checkpointing_enable()``
            is called on the model before the comparison.
    """
    # Evaluate the skip conditions first so that no model is built when the
    # test is going to be skipped anyway (skipTest raises and exits here).
    if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()):
        self.parent.skipTest(
            "This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found."
        )
    if torch_device != "cuda":
        self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.")

    model = Mamba2Model(config)
    model.eval()
    model.to(torch_device)
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    token_emb = model.embeddings(input_ids)
    mixer = model.layers[0].mixer
    outputs_fast = mixer.cuda_kernels_forward(token_emb)
    outputs_slow = mixer.torch_forward(token_emb)

    self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
|
||||
@@ -184,6 +239,14 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
|
||||
)
|
||||
|
||||
def test_mamba2_caching(self):
    # Hand freshly prepared config/inputs straight to the tester's caching check.
    self.model_tester.create_and_check_mamba2_caching(*self.model_tester.prepare_config_and_inputs())
|
||||
|
||||
def test_mamba2_slow_vs_fast_forward(self):
    # Hand freshly prepared config/inputs straight to the tester's slow-vs-fast check.
    self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*self.model_tester.prepare_config_and_inputs())
|
||||
|
||||
def test_initialization(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -199,23 +262,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
||||
# No-op body: running this test performs no checks.
# NOTE(review): any decorator on this method sits above the visible hunk — confirm.
def test_tied_weights_keys(self):
|
||||
pass
|
||||
|
||||
# Unconditionally skipped through ``unittest.skip``; the body is a no-op placeholder.
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
# Unconditionally skipped through ``unittest.skip``; the body is a no-op placeholder.
# The ``parameterized.expand`` cases are ("greedy", 1) and ("beam search", 2).
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||
@parameterized.expand([("greedy", 1), ("beam search", 2)])
|
||||
def test_generate_from_inputs_embeds(self, _, num_beams):
|
||||
pass
|
||||
|
||||
# Unconditionally skipped through ``unittest.skip``; the body is a no-op placeholder.
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
|
||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
# Unconditionally skipped through ``unittest.skip``; the body is a no-op placeholder.
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
|
||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||
pass
|
||||
|
||||
# Unconditionally skipped through ``unittest.skip``; the body is a no-op placeholder.
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
|
||||
def test_multi_gpu_data_parallel_forward(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user