Fix: Jamba batched generation (#32914)
* init fix
* fix mask during cached forward, move mask related stuff to own function
* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)
* revert overwriting new integration tests
* move some comments to docstring
This commit is contained in:
@@ -649,7 +649,12 @@ class JambaMambaMixer(nn.Module):
|
|||||||
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
|
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
|
||||||
)
|
)
|
||||||
|
|
||||||
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
|
def cuda_kernels_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
cache_params: HybridMambaAttentionDynamicCache = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
batch_size, seq_len, _ = hidden_states.shape
|
batch_size, seq_len, _ = hidden_states.shape
|
||||||
use_precomputed_states = (
|
use_precomputed_states = (
|
||||||
cache_params is not None
|
cache_params is not None
|
||||||
@@ -666,6 +671,9 @@ class JambaMambaMixer(nn.Module):
|
|||||||
# inner layernorms which isn't supported by this fused kernel
|
# inner layernorms which isn't supported by this fused kernel
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=1)
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
||||||
if use_precomputed_states:
|
if use_precomputed_states:
|
||||||
@@ -683,6 +691,9 @@ class JambaMambaMixer(nn.Module):
|
|||||||
cache_params.conv_states[self.layer_idx].copy_(conv_states)
|
cache_params.conv_states[self.layer_idx].copy_(conv_states)
|
||||||
hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
|
hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
# 3.a. input varying initialization of time_step, B and C
|
# 3.a. input varying initialization of time_step, B and C
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
||||||
@@ -742,14 +753,17 @@ class JambaMambaMixer(nn.Module):
|
|||||||
return contextualized_states
|
return contextualized_states
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
|
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask: Optional[torch.LongTensor] = None):
|
||||||
batch_size, seq_len, _ = input_states.shape
|
batch_size, seq_len, _ = input_states.shape
|
||||||
dtype = input_states.dtype
|
dtype = input_states.dtype
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
|
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=1)
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
use_cache = isinstance(cache_params,HybridMambaAttentionDynamicCache)
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
|
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
|
||||||
if self.training:
|
if self.training:
|
||||||
@@ -784,6 +798,9 @@ class JambaMambaMixer(nn.Module):
|
|||||||
)
|
)
|
||||||
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
|
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
|
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
||||||
@@ -821,14 +838,19 @@ class JambaMambaMixer(nn.Module):
|
|||||||
return contextualized_states
|
return contextualized_states
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states,
|
||||||
|
cache_params: HybridMambaAttentionDynamicCache = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
if self.use_fast_kernels:
|
if self.use_fast_kernels:
|
||||||
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
|
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
|
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
|
||||||
)
|
)
|
||||||
return self.cuda_kernels_forward(hidden_states, cache_params)
|
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
|
||||||
return self.slow_forward(hidden_states, cache_params)
|
return self.slow_forward(hidden_states, cache_params, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
|
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
|
||||||
@@ -1040,6 +1062,7 @@ class JambaMambaDecoderLayer(nn.Module):
|
|||||||
hidden_states = self.mamba(
|
hidden_states = self.mamba(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
cache_params=past_key_value,
|
cache_params=past_key_value,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
self_attn_weights = None
|
self_attn_weights = None
|
||||||
|
|
||||||
@@ -1279,12 +1302,16 @@ class JambaModel(JambaPreTrainedModel):
|
|||||||
position_ids = cache_position.unsqueeze(0)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
|
||||||
|
mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
|
||||||
|
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_router_logits = () if output_router_logits else None
|
all_router_logits = () if output_router_logits else None
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for decoder_layer in self.layers:
|
||||||
|
# Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
|
||||||
|
layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
@@ -1292,7 +1319,7 @@ class JambaModel(JambaPreTrainedModel):
|
|||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
causal_mask,
|
layer_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
past_key_values,
|
past_key_values,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
@@ -1303,7 +1330,7 @@ class JambaModel(JambaPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=causal_mask,
|
attention_mask=layer_mask,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -1384,6 +1411,17 @@ class JambaModel(JambaPreTrainedModel):
|
|||||||
|
|
||||||
return causal_mask
|
return causal_mask
|
||||||
|
|
||||||
|
def _update_mamba_mask(self, attention_mask, cache_position):
|
||||||
|
"""
|
||||||
|
No need for zeroing states when
|
||||||
|
1. Cached forward
|
||||||
|
2. Attending to all inputs
|
||||||
|
"""
|
||||||
|
mamba_mask = attention_mask
|
||||||
|
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
|
||||||
|
mamba_mask = None
|
||||||
|
return mamba_mask
|
||||||
|
|
||||||
|
|
||||||
# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
|
# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
|
||||||
class JambaForCausalLM(JambaPreTrainedModel):
|
class JambaForCausalLM(JambaPreTrainedModel):
|
||||||
|
|||||||
@@ -458,51 +458,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_left_padding_compatibility(self):
|
|
||||||
r"""
|
|
||||||
Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
|
|
||||||
effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
|
|
||||||
"""
|
|
||||||
import inspect
|
|
||||||
# NOTE: left-padding results in small numerical differences. This is expected.
|
|
||||||
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
|
|
||||||
|
|
||||||
# First, filter out models that don't support left padding - generative and decoder-only.
|
|
||||||
# Jamba is a decoder-only architecture
|
|
||||||
decoder_only_classes = self.all_generative_model_classes
|
|
||||||
|
|
||||||
# Then, test left-padding
|
|
||||||
def _prepare_model_kwargs(input_ids, attention_mask, signature):
|
|
||||||
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
|
|
||||||
if "position_ids" in signature:
|
|
||||||
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
|
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
|
||||||
model_kwargs["position_ids"] = position_ids
|
|
||||||
if "cache_position" in signature:
|
|
||||||
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
|
|
||||||
model_kwargs["cache_position"] = cache_position
|
|
||||||
return model_kwargs
|
|
||||||
|
|
||||||
for model_class in decoder_only_classes:
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
|
||||||
signature = inspect.signature(model.forward).parameters.keys()
|
|
||||||
|
|
||||||
# Without padding
|
|
||||||
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
|
|
||||||
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
|
|
||||||
|
|
||||||
# With left-padding (length 32)
|
|
||||||
pad_size = (input_ids.shape[0], 32)
|
|
||||||
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
|
|
||||||
padded_input_ids = torch.cat((padding, input_ids), dim=1)
|
|
||||||
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
|
|
||||||
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
|
|
||||||
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
|
|
||||||
|
|
||||||
# They should result in very similar logits
|
|
||||||
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
|
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_bitsandbytes
|
@require_bitsandbytes
|
||||||
@@ -692,7 +647,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
|
||||||
[
|
[
|
||||||
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
|
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
|
||||||
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
|
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
|
||||||
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
|
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
|
||||||
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
|
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
|
||||||
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
|
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
|
||||||
@@ -737,10 +692,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = self.model(input_ids=inputs["input_ids"]).logits
|
logits = self.model(input_ids=inputs["input_ids"]).logits
|
||||||
|
|
||||||
|
# TODO fix logits
|
||||||
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
|
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
|
||||||
[
|
[
|
||||||
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
|
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
|
||||||
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
|
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
|
||||||
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
|
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
|
||||||
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
|
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
|
||||||
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
|
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
|
||||||
@@ -749,7 +705,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
||||||
[
|
[
|
||||||
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
|
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
|
||||||
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
|
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
|
||||||
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
|
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
|
||||||
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
|
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
|
||||||
|
|||||||
Reference in New Issue
Block a user