[Mamba2] Move dt calculations to kernel (#33520)
* use kernel for dt calculations
* add small test
* [run-slow] mamba2
This commit is contained in:
@@ -358,7 +358,6 @@ class Mamba2Mixer(nn.Module):
|
|||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
time_step = nn.functional.softplus(time_step + self.dt_bias)
|
|
||||||
# 1D Convolution
|
# 1D Convolution
|
||||||
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
|
||||||
hidden_states_B_C = self.act(
|
hidden_states_B_C = self.act(
|
||||||
@@ -391,6 +390,8 @@ class Mamba2Mixer(nn.Module):
|
|||||||
z=None,
|
z=None,
|
||||||
seq_idx=None,
|
seq_idx=None,
|
||||||
return_final_states=True,
|
return_final_states=True,
|
||||||
|
dt_bias=self.dt_bias,
|
||||||
|
dt_softplus=True,
|
||||||
**dt_limit_kwargs,
|
**dt_limit_kwargs,
|
||||||
)
|
)
|
||||||
if ssm_state is not None and cache_params is not None:
|
if ssm_state is not None and cache_params is not None:
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ if is_torch_available():
|
|||||||
Mamba2ForCausalLM,
|
Mamba2ForCausalLM,
|
||||||
Mamba2Model,
|
Mamba2Model,
|
||||||
)
|
)
|
||||||
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
|
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
|
||||||
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
|
||||||
else:
|
else:
|
||||||
is_torch_greater_or_equal_than_2_0 = False
|
is_torch_greater_or_equal_than_2_0 = False
|
||||||
@@ -378,3 +378,27 @@ class Mamba2IntegrationTest(unittest.TestCase):
|
|||||||
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
|
||||||
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
|
||||||
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
|
||||||
|
|
||||||
|
@slow
@require_torch_gpu
def test_mamba2_mixer_train_vs_eval_equivalence(self):
    """Check that a Mamba2Mixer gives matching outputs in train and eval mode.

    The same random input is pushed through one freshly built mixer twice,
    once after ``mixer.train()`` and once after ``mixer.eval()``, and the two
    outputs are compared with an absolute tolerance of 1e-3.
    """
    # Based on https://github.com/sustcsonglin/flash-linear-attention/issues/63
    # Credit to zhixuan-lin

    batch_size, seq_len, hidden_size = 4, 512, 768
    autocast_dtype = torch.bfloat16
    config = Mamba2Config(num_heads=24, head_dim=64, hidden_size=768, expand=2, n_groups=1)

    # Seed before the mixer is constructed so that parameter initialisation
    # and the random input below are reproducible.
    torch.manual_seed(42)
    # NOTE(review): indentation was lost in the diff rendering; everything
    # below is assumed to sit inside both context managers — confirm.
    with torch.amp.autocast(device_type="cuda", dtype=autocast_dtype), torch.no_grad():
        mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
        hidden_states = torch.rand(
            size=(batch_size, seq_len, hidden_size),
            dtype=autocast_dtype,
            device="cuda",
        )

        mixer.train()
        out_train = mixer(hidden_states)

        mixer.eval()
        out_eval = mixer(hidden_states)

        outputs_match = torch.allclose(out_train, out_eval, atol=1e-3)
        self.assertTrue(outputs_match)
|
||||||
|
|||||||
Reference in New Issue
Block a user