Mamba / FalconMamba: Fix mamba left padding (#32677)
* fix mamba left padding * Apply suggestions from code review Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> * fix copies * test with `inputs_embeds` * Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * copies * clairfy * fix last comments * remove --------- Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -155,6 +155,7 @@ class FalconMambaMixer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states = self.in_proj(hidden_states).transpose(1, 2)
|
projected_states = self.in_proj(hidden_states).transpose(1, 2)
|
||||||
@@ -179,6 +180,9 @@ class FalconMambaMixer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=1)
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
||||||
if cache_params is not None and cache_position[0] > 0:
|
if cache_params is not None and cache_position[0] > 0:
|
||||||
@@ -200,6 +204,9 @@ class FalconMambaMixer(nn.Module):
|
|||||||
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
|
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
# 3.a. input varying initialization of time_step, B and C
|
# 3.a. input varying initialization of time_step, B and C
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
||||||
@@ -259,6 +266,7 @@ class FalconMambaMixer(nn.Module):
|
|||||||
input_states,
|
input_states,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
batch_size, seq_len, _ = input_states.shape
|
batch_size, seq_len, _ = input_states.shape
|
||||||
dtype = input_states.dtype
|
dtype = input_states.dtype
|
||||||
@@ -266,6 +274,9 @@ class FalconMambaMixer(nn.Module):
|
|||||||
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
|
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=1)
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
|
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
|
||||||
@@ -294,6 +305,9 @@ class FalconMambaMixer(nn.Module):
|
|||||||
)
|
)
|
||||||
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
|
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
|
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
||||||
@@ -355,10 +369,11 @@ class FalconMambaMixer(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
|
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
|
||||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
|
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||||
return self.slow_forward(hidden_states, cache_params, cache_position)
|
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
|
# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
|
||||||
@@ -396,13 +411,16 @@ class FalconMambaBlock(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
|
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
|
||||||
if self.residual_in_fp32:
|
if self.residual_in_fp32:
|
||||||
residual = residual.to(torch.float32)
|
residual = residual.to(torch.float32)
|
||||||
|
|
||||||
hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
|
hidden_states = self.mixer(
|
||||||
|
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
|
||||||
|
)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -601,14 +619,13 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
|
|
||||||
inputs_embeds: Optional[torch.LongTensor] = None,
|
inputs_embeds: Optional[torch.LongTensor] = None,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, FalconMambaOutput]:
|
) -> Union[Tuple, FalconMambaOutput]:
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -649,10 +666,15 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
|
|||||||
for mixer_block in self.layers:
|
for mixer_block in self.layers:
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
hidden_states = self._gradient_checkpointing_func(
|
hidden_states = self._gradient_checkpointing_func(
|
||||||
mixer_block.__call__, hidden_states, cache_params, cache_position
|
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
|
hidden_states = mixer_block(
|
||||||
|
hidden_states,
|
||||||
|
cache_params=cache_params,
|
||||||
|
cache_position=cache_position,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
@@ -712,6 +734,13 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
|
|||||||
and model_kwargs["cache_position"] is not None
|
and model_kwargs["cache_position"] is not None
|
||||||
):
|
):
|
||||||
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
||||||
|
|
||||||
|
if "attention_mask" in model_kwargs:
|
||||||
|
attention_mask = model_kwargs["attention_mask"]
|
||||||
|
model_kwargs["attention_mask"] = torch.cat(
|
||||||
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
@@ -721,6 +750,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@@ -733,6 +763,10 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
if cache_position[0] > 0:
|
if cache_position[0] > 0:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||||
# considering padding will be applied when input length is shorter, and truncation
|
# considering padding will be applied when input length is shorter, and truncation
|
||||||
@@ -750,6 +784,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
|
|||||||
"cache_params": cache_params,
|
"cache_params": cache_params,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
@@ -760,11 +795,10 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
|
|||||||
output_type=FalconMambaCausalLMOutput,
|
output_type=FalconMambaCausalLMOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
)
|
)
|
||||||
# Ignore copy
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.LongTensor] = None, # Ignored copy
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
@@ -790,6 +824,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
hidden_states = falcon_mamba_outputs[0]
|
hidden_states = falcon_mamba_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ class MambaMixer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states = self.in_proj(hidden_states).transpose(1, 2)
|
projected_states = self.in_proj(hidden_states).transpose(1, 2)
|
||||||
@@ -160,6 +161,9 @@ class MambaMixer(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=1)
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
|
||||||
if cache_params is not None and cache_position[0] > 0:
|
if cache_params is not None and cache_position[0] > 0:
|
||||||
@@ -181,6 +185,9 @@ class MambaMixer(nn.Module):
|
|||||||
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
|
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
# 3.a. input varying initialization of time_step, B and C
|
# 3.a. input varying initialization of time_step, B and C
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
||||||
@@ -226,13 +233,16 @@ class MambaMixer(nn.Module):
|
|||||||
return contextualized_states
|
return contextualized_states
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None):
|
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
|
||||||
batch_size, seq_len, _ = input_states.shape
|
batch_size, seq_len, _ = input_states.shape
|
||||||
dtype = input_states.dtype
|
dtype = input_states.dtype
|
||||||
# 1. Gated MLP's linear projection
|
# 1. Gated MLP's linear projection
|
||||||
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
|
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
|
||||||
hidden_states, gate = projected_states.chunk(2, dim=1)
|
hidden_states, gate = projected_states.chunk(2, dim=1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 2. Convolution sequence transformation
|
# 2. Convolution sequence transformation
|
||||||
if cache_params is not None:
|
if cache_params is not None:
|
||||||
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
|
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
|
||||||
@@ -261,6 +271,9 @@ class MambaMixer(nn.Module):
|
|||||||
)
|
)
|
||||||
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
|
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
hidden_states = hidden_states * attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
# 3. State Space Model sequence transformation
|
# 3. State Space Model sequence transformation
|
||||||
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
|
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
|
||||||
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
|
||||||
@@ -306,10 +319,11 @@ class MambaMixer(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
|
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
|
||||||
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
|
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||||
return self.slow_forward(hidden_states, cache_params, cache_position)
|
return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
|
||||||
|
|
||||||
|
|
||||||
class MambaRMSNorm(nn.Module):
|
class MambaRMSNorm(nn.Module):
|
||||||
@@ -346,13 +360,16 @@ class MambaBlock(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
|
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
|
||||||
if self.residual_in_fp32:
|
if self.residual_in_fp32:
|
||||||
residual = residual.to(torch.float32)
|
residual = residual.to(torch.float32)
|
||||||
|
|
||||||
hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
|
hidden_states = self.mixer(
|
||||||
|
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
|
||||||
|
)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -563,7 +580,7 @@ class MambaModel(MambaPreTrainedModel):
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, MambaOutput]:
|
) -> Union[Tuple, MambaOutput]:
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
@@ -605,10 +622,15 @@ class MambaModel(MambaPreTrainedModel):
|
|||||||
for mixer_block in self.layers:
|
for mixer_block in self.layers:
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
hidden_states = self._gradient_checkpointing_func(
|
hidden_states = self._gradient_checkpointing_func(
|
||||||
mixer_block.__call__, hidden_states, cache_params, cache_position
|
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
|
hidden_states = mixer_block(
|
||||||
|
hidden_states,
|
||||||
|
cache_params=cache_params,
|
||||||
|
cache_position=cache_position,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
@@ -668,6 +690,12 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
):
|
):
|
||||||
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
||||||
|
|
||||||
|
if "attention_mask" in model_kwargs:
|
||||||
|
attention_mask = model_kwargs["attention_mask"]
|
||||||
|
model_kwargs["attention_mask"] = torch.cat(
|
||||||
|
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
@@ -677,6 +705,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@@ -689,6 +718,10 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
)
|
)
|
||||||
if cache_position[0] > 0:
|
if cache_position[0] > 0:
|
||||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
|
||||||
# considering padding will be applied when input length is shorter, and truncation
|
# considering padding will be applied when input length is shorter, and truncation
|
||||||
@@ -706,6 +739,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
"cache_params": cache_params,
|
"cache_params": cache_params,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
@@ -719,6 +753,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
cache_params: Optional[MambaCache] = None,
|
cache_params: Optional[MambaCache] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
@@ -744,6 +779,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
|
|||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
hidden_states = mamba_outputs[0]
|
hidden_states = mamba_outputs[0]
|
||||||
|
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ class FalconMambaModelTester:
|
|||||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||||
):
|
):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)
|
||||||
|
|
||||||
sequence_labels = None
|
sequence_labels = None
|
||||||
token_labels = None
|
token_labels = None
|
||||||
@@ -119,7 +120,7 @@ class FalconMambaModelTester:
|
|||||||
return (
|
return (
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
None,
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
@@ -153,6 +154,7 @@ class FalconMambaModelTester:
|
|||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
@@ -161,6 +163,7 @@ class FalconMambaModelTester:
|
|||||||
return (
|
return (
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
@@ -253,12 +256,12 @@ class FalconMambaModelTester:
|
|||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
_,
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
) = self.prepare_config_and_inputs()
|
) = self.prepare_config_and_inputs()
|
||||||
inputs_dict = {"input_ids": input_ids}
|
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -491,3 +494,33 @@ class FalconMambaIntegrationTests(unittest.TestCase):
|
|||||||
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
|
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
|
||||||
"Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep",
|
"Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_batched_generation(self):
|
||||||
|
model_id = "tiiuae/falcon-mamba-7b"
|
||||||
|
tok = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
tok.pad_token_id = tok.eos_token_id
|
||||||
|
|
||||||
|
texts = ["Hello today", "Hello my name is Younes and today"]
|
||||||
|
|
||||||
|
EXPECTED_OUTPUT = [
|
||||||
|
"Hello today I'm going to show you how to make a 3D model of a house.\n",
|
||||||
|
"Hello my name is Younes and today I will be talking about the topic of “The importance of the internet in our life”.\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
out = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
# We test the same generations with inputs_embeds
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
|
||||||
|
|
||||||
|
inputs["inputs_embeds"] = inputs_embeds
|
||||||
|
out = model.generate(**inputs, max_new_tokens=20)
|
||||||
|
out = tok.batch_decode(out, skip_special_tokens=True)
|
||||||
|
|
||||||
|
self.assertListEqual(out, EXPECTED_OUTPUT)
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ class MambaModelTester:
|
|||||||
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
|
||||||
):
|
):
|
||||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)
|
||||||
|
|
||||||
sequence_labels = None
|
sequence_labels = None
|
||||||
token_labels = None
|
token_labels = None
|
||||||
@@ -112,7 +113,7 @@ class MambaModelTester:
|
|||||||
return (
|
return (
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
None,
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
@@ -146,6 +147,7 @@ class MambaModelTester:
|
|||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
@@ -154,6 +156,7 @@ class MambaModelTester:
|
|||||||
return (
|
return (
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
@@ -246,12 +249,12 @@ class MambaModelTester:
|
|||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
input_ids,
|
input_ids,
|
||||||
_,
|
attention_mask,
|
||||||
sequence_labels,
|
sequence_labels,
|
||||||
token_labels,
|
token_labels,
|
||||||
choice_labels,
|
choice_labels,
|
||||||
) = self.prepare_config_and_inputs()
|
) = self.prepare_config_and_inputs()
|
||||||
inputs_dict = {"input_ids": input_ids}
|
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user