[Mamba2] Fix caching, slow path, and multi-gpu (#35154)

* fixup mamba2 - caching and several other small fixes

* fixup cached forward

* correct fix this time

* fixup cache - we do not need to extend the attn mask it's handled by generate (gives total ids + mask at each step)

* remove unnecessary (un)squeeze

* fixup cache position

* simplify a few things

* [run-slow] mamba2

* multi gpu attempt two

* [run-slow] mamba2

* [run-slow] mamba2

* [run-slow] mamba2

* [run-slow] mamba2

* add newer slow path fix

* [run-slow] mamba2
This commit is contained in:
Anton Vlasjuk
2024-12-20 03:27:47 -05:00
committed by GitHub
parent ff9141bb85
commit 5a2aedca1e
2 changed files with 263 additions and 177 deletions

View File

@@ -21,6 +21,7 @@ from parameterized import parameterized
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
@@ -103,6 +104,10 @@ class Mamba2ModelTester:
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
# Only left padding is valid
attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
attention_mask[0, :1] = 0
sequence_labels = None
token_labels = None
choice_labels = None
@@ -118,7 +123,7 @@ class Mamba2ModelTester:
return (
config,
input_ids,
None,
attention_mask,
sequence_labels,
token_labels,
choice_labels,
@@ -158,6 +163,56 @@ class Mamba2ModelTester:
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
def create_and_check_mamba2_caching(self, config, input_ids, attention_mask, *args):
model = Mamba2Model(config=config)
model.to(torch_device)
model.eval()
output_whole = model(input_ids, attention_mask=attention_mask).last_hidden_state
outputs = model(
input_ids[:, :-1],
attention_mask=attention_mask[:, :-1],
use_cache=True,
cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device),
)
output_one = outputs.last_hidden_state
# Using the state computed on the first inputs, we will get the same output
outputs = model(
input_ids[:, -1:],
attention_mask=attention_mask[:, -1:],
use_cache=True,
cache_params=outputs.cache_params,
cache_position=torch.arange(config.conv_kernel, config.conv_kernel + 1, device=input_ids.device),
)
output_two = outputs.last_hidden_state
self.parent.assertTrue(
torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-3, rtol=1e-3)
)
def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False):
model = Mamba2Model(config)
model.eval()
if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()):
self.parent.skipTest(
"This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found."
)
if torch_device != "cuda":
self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.")
model.to(torch_device)
if gradient_checkpointing:
model.gradient_checkpointing_enable()
token_emb = model.embeddings(input_ids)
outputs_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb)
outputs_slow = model.layers[0].mixer.torch_forward(token_emb)
self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
@unittest.skipIf(
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
@@ -184,6 +239,14 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)
def test_mamba2_caching(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mamba2_caching(*config_and_inputs)
def test_mamba2_slow_vs_fast_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -199,23 +262,6 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
@parameterized.expand([("greedy", 1), ("beam search", 2)])
def test_generate_from_inputs_embeds(self, _, num_beams):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass