Fix: Jamba batched generation (#32914)

* init fix

* fix mask during cached forward, move mask related stuff to own function

* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)

* revert overwriting new integration tests

* move some comments to docstring
This commit is contained in:
Anton Vlasjuk
2024-08-28 09:24:06 +02:00
committed by GitHub
parent 386931d950
commit 3bfd3e4803
2 changed files with 50 additions and 56 deletions

View File

@@ -649,7 +649,12 @@ class JambaMambaMixer(nn.Module):
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
)
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: HybridMambaAttentionDynamicCache = None,
attention_mask: Optional[torch.LongTensor] = None,
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
@@ -666,6 +671,9 @@ class JambaMambaMixer(nn.Module):
# inner layernorms which isn't supported by this fused kernel
hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if use_precomputed_states:
@@ -683,6 +691,9 @@ class JambaMambaMixer(nn.Module):
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -742,14 +753,17 @@ class JambaMambaMixer(nn.Module):
return contextualized_states
# fmt: off
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1)
use_cache = isinstance(cache_params,HybridMambaAttentionDynamicCache)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
# 2. Convolution sequence transformation
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
if self.training:
@@ -784,6 +798,9 @@ class JambaMambaMixer(nn.Module):
)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -821,14 +838,19 @@ class JambaMambaMixer(nn.Module):
return contextualized_states
# fmt: on
def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
def forward(
self,
hidden_states,
cache_params: HybridMambaAttentionDynamicCache = None,
attention_mask: Optional[torch.LongTensor] = None,
):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
raise ValueError(
"Fast Mamba kernels are not available. Make sure to they are installed and that the mamba module is on a CUDA device"
)
return self.cuda_kernels_forward(hidden_states, cache_params)
return self.slow_forward(hidden_states, cache_params)
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
return self.slow_forward(hidden_states, cache_params, attention_mask)
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
@@ -1040,6 +1062,7 @@ class JambaMambaDecoderLayer(nn.Module):
hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
attention_mask=attention_mask,
)
self_attn_weights = None
@@ -1279,12 +1302,16 @@ class JambaModel(JambaPreTrainedModel):
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_router_logits = () if output_router_logits else None
for decoder_layer in self.layers:
# Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
if output_hidden_states:
all_hidden_states += (hidden_states,)
@@ -1292,7 +1319,7 @@ class JambaModel(JambaPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
layer_mask,
position_ids,
past_key_values,
output_attentions,
@@ -1303,7 +1330,7 @@ class JambaModel(JambaPreTrainedModel):
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
attention_mask=layer_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
@@ -1384,6 +1411,17 @@ class JambaModel(JambaPreTrainedModel):
return causal_mask
def _update_mamba_mask(self, attention_mask, cache_position):
"""
No need for zeroing states when
1. Cached forward
2. Attending to all inputs
"""
mamba_mask = attention_mask
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
mamba_mask = None
return mamba_mask
# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
class JambaForCausalLM(JambaPreTrainedModel):

View File

@@ -458,51 +458,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_left_padding_compatibility(self):
r"""
Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
"""
import inspect
# NOTE: left-padding results in small numerical differences. This is expected.
# See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
# First, filter out models that don't support left padding - generative and decoder-only.
# Jamba is a decoder-only architecture
decoder_only_classes = self.all_generative_model_classes
# Then, test left-padding
def _prepare_model_kwargs(input_ids, attention_mask, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
return model_kwargs
for model_class in decoder_only_classes:
config, input_ids, attention_mask = self._get_input_ids_and_config()
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
# With left-padding (length 32)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@@ -692,7 +647,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
EXPECTED_LOGITS_NO_GRAD = torch.tensor(
[
0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
-0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
@@ -737,10 +692,11 @@ class JambaModelIntegrationTest(unittest.TestCase):
with torch.no_grad():
logits = self.model(input_ids=inputs["input_ids"]).logits
# TODO fix logits
EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
[
0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
-0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
@@ -749,7 +705,7 @@ class JambaModelIntegrationTest(unittest.TestCase):
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
[
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
-0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,