From c269c5c74d08b00e8fc86c620997b6d639016aee Mon Sep 17 00:00:00 2001 From: Adibvafa Fallahpour <90617686+Adibvafa@users.noreply.github.com> Date: Tue, 1 Oct 2024 03:28:40 -0400 Subject: [PATCH] Fix Mamba slow path bug with dtype mismatch. (#32691) * Fix Mamba slow path bug with dtype mismatch. * Update test_modeling_mamba.py * Improve style. * Fix issue with cache position of dtype mismatch test. * Change test for slow path. * Revert changes. * Switch to buggy code and add test to catch it. * Fix the dtype mismatch bug and add test code to verify it. * Fix minor bug with test. * Fix incorrect dtype of model output. * Fix incorrect dtype of cache. * Fix incorrect dtype of ssm cache. * Fix incorrect dtype of conv state. * Remove assertion for ssm state. * Add assertion for conv state dtype. * Fix all issues with dtype mismatch test. --- src/transformers/cache_utils.py | 2 +- tests/models/mamba/test_modeling_mamba.py | 24 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d41bc99eea..0b82b17dcd 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1797,7 +1797,7 @@ class MambaCache: cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device) + conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype) self.conv_states[layer_idx].zero_() self.conv_states[layer_idx] += conv_state return self.conv_states[layer_idx] diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 3b4a18bb48..d432dfa93d 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -421,6 +421,30 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi def test_beam_sample_generate(self): pass + def test_dtype_mismatch_handled_in_cache(self): + config, input_ids, *args = self.model_tester.prepare_config_and_inputs() + model = MambaModel(config) + model.to(torch_device).to(torch.float16) + model.eval() + + # Create cache with float32 dtype + cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device) + + # If code is correct, no error occurs and test passes + outputs = model( + input_ids, + cache_params=cache_params, + use_cache=True, + cache_position=torch.arange(0, config.conv_kernel, device=input_ids.device), + ) + + self.assertIsNotNone(outputs) + self.assertIsNotNone(outputs.last_hidden_state) + self.assertEqual( + outputs.last_hidden_state.shape, + (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.hidden_size), + ) + @require_torch class MambaIntegrationTests(unittest.TestCase):