[Mamba2] Move dt calculations to kernel (#33520)

* use kernel for dt calculations

* add small test

* [run-slow] mamba2
This commit is contained in:
Anton Vlasjuk
2024-09-19 18:41:17 +02:00
committed by GitHub
parent 162056a3f4
commit b50ff5993a
2 changed files with 27 additions and 2 deletions

View File

@@ -358,7 +358,6 @@ class Mamba2Mixer(nn.Module):
dim=-1, dim=-1,
) )
time_step = nn.functional.softplus(time_step + self.dt_bias)
# 1D Convolution # 1D Convolution
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]: if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
hidden_states_B_C = self.act( hidden_states_B_C = self.act(
@@ -391,6 +390,8 @@ class Mamba2Mixer(nn.Module):
z=None, z=None,
seq_idx=None, seq_idx=None,
return_final_states=True, return_final_states=True,
dt_bias=self.dt_bias,
dt_softplus=True,
**dt_limit_kwargs, **dt_limit_kwargs,
) )
if ssm_state is not None and cache_params is not None: if ssm_state is not None and cache_params is not None:

View File

@@ -35,7 +35,7 @@ if is_torch_available():
Mamba2ForCausalLM, Mamba2ForCausalLM,
Mamba2Model, Mamba2Model,
) )
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
else: else:
is_torch_greater_or_equal_than_2_0 = False is_torch_greater_or_equal_than_2_0 = False
@@ -378,3 +378,27 @@ class Mamba2IntegrationTest(unittest.TestCase):
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True) individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0] individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
self.assertEqual(individual_output[:100], batched_output[index_gen][:100]) self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
@slow
@require_torch_gpu
def test_mamba2_mixer_train_vs_eval_equivalence(self):
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
# Credit to zhixuan-lin
B, T, D = 4, 512, 768
dtype = torch.bfloat16
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
torch.manual_seed(42)
with torch.amp.autocast(device_type="cuda", dtype=dtype):
with torch.no_grad():
mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
mixer.train()
out_train = mixer(hidden_states)
mixer.eval()
out_eval = mixer(hidden_states)
self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))