From cefb819f7a54009b36493b90878ee9b3b198039f Mon Sep 17 00:00:00 2001 From: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com> Date: Wed, 27 Mar 2024 04:52:12 +0100 Subject: [PATCH] Mamba `slow_forward` gradient fix (#29563) * FIX: Cached slow forward in mamba - additionally added mamba cached test - added unused test (mamba causal lm forward and backward) - fixed typo: "causl" --> "causal" * formatting * fix: use real `slow_forward` call instead of torch module's * add shape assertion for mixer block test * adjust shape assertion --- .../models/mamba/modeling_mamba.py | 2 +- tests/models/mamba/test_modeling_mamba.py | 37 +++++++++++++++++-- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 0e233ae430..00e51e5090 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -230,7 +230,7 @@ class MambaMixer(nn.Module): # 2. Convolution sequence transformation if cache_params is not None: - ssm_state = cache_params.ssm_states[self.layer_idx] + ssm_state = cache_params.ssm_states[self.layer_idx].clone() if cache_params.seqlen_offset > 0: conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] conv_state = torch.roll(conv_state, shifts=-1, dims=-1) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 8bd121933b..3b77e26dcc 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -170,7 +170,7 @@ class MambaModelTester: self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1) - def create_and_check_causl_lm(self, config, input_ids, *args): + def create_and_check_causal_lm(self, config, input_ids, *args): model = MambaForCausalLM(config) model.to(torch_device) model.eval() @@ -197,7 +197,30 @@ class MambaModelTester: self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) # TODO the orignal mamba does not support decoding more than 1 token neither do we - def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False): + def create_and_check_mamba_cached_slow_forward_and_backwards( + self, config, input_ids, *args, gradient_checkpointing=False + ): + model = MambaModel(config) + model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + + # create cache + cache = model(input_ids, use_cache=True).cache_params + cache.seqlen_offset = 0 + + # use cache + token_emb = model.embeddings(input_ids) + outputs = model.layers[0].mixer.slow_forward(token_emb, cache) + + loss = torch.log(1 + torch.abs(outputs.sum())) + self.parent.assertEqual(loss.shape, ()) + self.parent.assertEqual(outputs.shape, (self.batch_size, self.seq_length, self.hidden_size)) + loss.backward() + + def create_and_check_mamba_lm_head_forward_and_backwards( + self, config, input_ids, *args, gradient_checkpointing=False + ): model = MambaForCausalLM(config) model.to(torch_device) if gradient_checkpointing: @@ -304,12 +327,20 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_mamba_lm_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_causl_lm(*config_and_inputs) + self.model_tester.create_and_check_causal_lm(*config_and_inputs) def test_state_equivalency(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_state_equivalency(*config_and_inputs) + def test_mamba_cached_slow_forward_and_backwards(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba_cached_slow_forward_and_backwards(*config_and_inputs) + + def test_mamba_lm_head_forward_and_backwards(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba_lm_head_forward_and_backwards(*config_and_inputs) + def test_initialization(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common()