From 4005730044dd4de1575052eec24c1860ef97ccf9 Mon Sep 17 00:00:00 2001 From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Date: Wed, 16 Apr 2025 16:16:01 -0400 Subject: [PATCH] Fix Mamba2 Grouped SSD Support in the torch_forward Path (#37533) * Fix mamba2 grouped support in bamba torch path * patch zamba2 and mamba2 * Add a unit test for grouped SSD * add comment for the new unit test * add output_size arg value to repeat_interleave calls * Add comment --- src/transformers/models/bamba/modeling_bamba.py | 4 ++-- src/transformers/models/bamba/modular_bamba.py | 4 ++-- src/transformers/models/mamba2/modeling_mamba2.py | 4 ++-- src/transformers/models/zamba2/modeling_zamba2.py | 4 ++-- src/transformers/models/zamba2/modular_zamba2.py | 4 ++-- tests/models/mamba2/test_modeling_mamba2.py | 8 ++++++++ 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 62496ba4eb..0fb696c1f7 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -783,8 +783,8 @@ class BambaMixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 8d64fdd989..3e0eb513d5 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -580,8 +580,8 @@ class BambaMixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 070af4016a..a1ca8d095c 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -572,8 +572,8 @@ class Mamba2Mixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 7854b04bc6..ad6ca0f872 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -860,8 +860,8 @@ class Zamba2MambaMixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 2c672bba5a..cd0af82658 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -630,8 +630,8 @@ class Zamba2MambaMixer(nn.Module): hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() - B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) - C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) + C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads) pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py index d55655c6ac..31565bf23d 100644 --- a/tests/models/mamba2/test_modeling_mamba2.py +++ b/tests/models/mamba2/test_modeling_mamba2.py @@ -238,6 +238,14 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs) + # This test adjusts n_groups to half the original setting and effectively + # creates a grouped SSD configuration in the mamba2 layers + # See https://github.com/huggingface/transformers/pull/37533/ + def test_mamba2_slow_vs_fast_forward_grouped(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs[0].n_groups //= 2 + self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs) + def test_initialization(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common()