[Mamba2] Move dt calculations to kernel (#33520)
* use kernel for dt calculations * add small test * [run-slow] mamba2
This commit is contained in:
@@ -358,7 +358,6 @@ class Mamba2Mixer(nn.Module):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
time_step = nn.functional.softplus(time_step + self.dt_bias)
|
||||
# 1D Convolution
|
||||
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
||||
hidden_states_B_C = self.act(
|
||||
@@ -391,6 +390,8 @@ class Mamba2Mixer(nn.Module):
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
return_final_states=True,
|
||||
dt_bias=self.dt_bias,
|
||||
dt_softplus=True,
|
||||
**dt_limit_kwargs,
|
||||
)
|
||||
if ssm_state is not None and cache_params is not None:
|
||||
|
||||
@@ -35,7 +35,7 @@ if is_torch_available():
|
||||
Mamba2ForCausalLM,
|
||||
Mamba2Model,
|
||||
)
|
||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
|
||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
|
||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||
else:
|
||||
is_torch_greater_or_equal_than_2_0 = False
|
||||
@@ -378,3 +378,27 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
||||
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
||||
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_mamba2_mixer_train_vs_eval_equivalence(self):
|
||||
# Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
|
||||
# Credit to zhixuan-lin
|
||||
|
||||
B, T, D = 4, 512, 768
|
||||
dtype = torch.bfloat16
|
||||
config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)
|
||||
|
||||
torch.manual_seed(42)
|
||||
with torch.amp.autocast(device_type="cuda", dtype=dtype):
|
||||
with torch.no_grad():
|
||||
mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
|
||||
hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")
|
||||
|
||||
mixer.train()
|
||||
out_train = mixer(hidden_states)
|
||||
|
||||
mixer.eval()
|
||||
out_eval = mixer(hidden_states)
|
||||
|
||||
self.assertTrue(torch.allclose(out_train, out_eval, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user