Mamba / FalconMamba: Fix mamba left padding (#32677)

* fix mamba left padding

* Apply suggestions from code review

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>

* fix copies

* test with `inputs_embeds`

* Update src/transformers/models/falcon_mamba/modeling_falcon_mamba.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* copies

* clairfy

* fix last comments

* remove

---------

Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Younes Belkada
2024-08-19 18:01:35 +04:00
committed by GitHub
parent 59e8f1919c
commit 93e538ae2e
4 changed files with 129 additions and 22 deletions

View File

@@ -155,6 +155,7 @@ class FalconMambaMixer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2) projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -179,6 +180,9 @@ class FalconMambaMixer(nn.Module):
else: else:
hidden_states, gate = projected_states.chunk(2, dim=1) hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0: if cache_params is not None and cache_position[0] > 0:
@@ -200,6 +204,9 @@ class FalconMambaMixer(nn.Module):
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
) )
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C # 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -259,6 +266,7 @@ class FalconMambaMixer(nn.Module):
input_states, input_states,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
batch_size, seq_len, _ = input_states.shape batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype dtype = input_states.dtype
@@ -266,6 +274,9 @@ class FalconMambaMixer(nn.Module):
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1) hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
if cache_params is not None: if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone() ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -294,6 +305,9 @@ class FalconMambaMixer(nn.Module):
) )
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -355,10 +369,11 @@ class FalconMambaMixer(nn.Module):
hidden_states, hidden_states,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position) return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position) return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
# Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba # Copied from transformers.models.mamba.modeling_mamba.MambaRMSNorm with Mamba->FalconMamba
@@ -396,13 +411,16 @@ class FalconMambaBlock(nn.Module):
hidden_states, hidden_states,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32: if self.residual_in_fp32:
residual = residual.to(torch.float32) residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position) hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
return hidden_states return hidden_states
@@ -601,14 +619,13 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored arg
inputs_embeds: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it attention_mask: Optional[torch.LongTensor] = None,
) -> Union[Tuple, FalconMambaOutput]: ) -> Union[Tuple, FalconMambaOutput]:
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -649,10 +666,15 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
for mixer_block in self.layers: for mixer_block in self.layers:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func( hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
) )
else: else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position) hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@@ -712,6 +734,13 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
and model_kwargs["cache_position"] is not None and model_kwargs["cache_position"] is not None
): ):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return model_kwargs return model_kwargs
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
@@ -721,6 +750,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
use_cache=None, use_cache=None,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
): ):
if use_cache: if use_cache:
@@ -733,6 +763,10 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
) )
if cache_position[0] > 0: if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
if attention_mask is not None:
attention_mask = None
else: else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage # we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation # considering padding will be applied when input length is shorter, and truncation
@@ -750,6 +784,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
"cache_params": cache_params, "cache_params": cache_params,
"use_cache": use_cache, "use_cache": use_cache,
"cache_position": cache_position, "cache_position": cache_position,
"attention_mask": attention_mask,
} }
) )
return model_inputs return model_inputs
@@ -760,11 +795,10 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
output_type=FalconMambaCausalLMOutput, output_type=FalconMambaCausalLMOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
# Ignore copy
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # Ignored copy attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@@ -790,6 +824,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
attention_mask=attention_mask,
) )
hidden_states = falcon_mamba_outputs[0] hidden_states = falcon_mamba_outputs[0]

View File

@@ -136,6 +136,7 @@ class MambaMixer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states).transpose(1, 2) projected_states = self.in_proj(hidden_states).transpose(1, 2)
@@ -160,6 +161,9 @@ class MambaMixer(nn.Module):
else: else:
hidden_states, gate = projected_states.chunk(2, dim=1) hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if cache_params is not None and cache_position[0] > 0: if cache_params is not None and cache_position[0] > 0:
@@ -181,6 +185,9 @@ class MambaMixer(nn.Module):
hidden_states, conv_weights, self.conv1d.bias, activation=self.activation hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
) )
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C # 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -226,13 +233,16 @@ class MambaMixer(nn.Module):
return contextualized_states return contextualized_states
# fmt: off # fmt: off
def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None): def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.LongTensor] = None):
batch_size, seq_len, _ = input_states.shape batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype dtype = input_states.dtype
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len]
hidden_states, gate = projected_states.chunk(2, dim=1) hidden_states, gate = projected_states.chunk(2, dim=1)
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 2. Convolution sequence transformation # 2. Convolution sequence transformation
if cache_params is not None: if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone() ssm_state = cache_params.ssm_states[self.layer_idx].clone()
@@ -261,6 +271,9 @@ class MambaMixer(nn.Module):
) )
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len]
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(1)
# 3. State Space Model sequence transformation # 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
@@ -306,10 +319,11 @@ class MambaMixer(nn.Module):
hidden_states, hidden_states,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling(): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position) return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.slow_forward(hidden_states, cache_params, cache_position) return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)
class MambaRMSNorm(nn.Module): class MambaRMSNorm(nn.Module):
@@ -346,13 +360,16 @@ class MambaBlock(nn.Module):
hidden_states, hidden_states,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32: if self.residual_in_fp32:
residual = residual.to(torch.float32) residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position) hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
return hidden_states return hidden_states
@@ -563,7 +580,7 @@ class MambaModel(MambaPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it attention_mask: Optional[torch.LongTensor] = None,
) -> Union[Tuple, MambaOutput]: ) -> Union[Tuple, MambaOutput]:
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -605,10 +622,15 @@ class MambaModel(MambaPreTrainedModel):
for mixer_block in self.layers: for mixer_block in self.layers:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func( hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
) )
else: else:
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position) hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
@@ -668,6 +690,12 @@ class MambaForCausalLM(MambaPreTrainedModel):
): ):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
return model_kwargs return model_kwargs
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
@@ -677,6 +705,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
use_cache=None, use_cache=None,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
): ):
if use_cache: if use_cache:
@@ -689,6 +718,10 @@ class MambaForCausalLM(MambaPreTrainedModel):
) )
if cache_position[0] > 0: if cache_position[0] > 0:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
if attention_mask is not None:
attention_mask = None
else: else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage # we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation # considering padding will be applied when input length is shorter, and truncation
@@ -706,6 +739,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
"cache_params": cache_params, "cache_params": cache_params,
"use_cache": use_cache, "use_cache": use_cache,
"cache_position": cache_position, "cache_position": cache_position,
"attention_mask": attention_mask,
} }
) )
return model_inputs return model_inputs
@@ -719,6 +753,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
def forward( def forward(
self, self,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[MambaCache] = None, cache_params: Optional[MambaCache] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
@@ -744,6 +779,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
use_cache=use_cache, use_cache=use_cache,
cache_position=cache_position, cache_position=cache_position,
attention_mask=attention_mask,
) )
hidden_states = mamba_outputs[0] hidden_states = mamba_outputs[0]

View File

@@ -101,6 +101,7 @@ class FalconMambaModelTester:
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
): ):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)
sequence_labels = None sequence_labels = None
token_labels = None token_labels = None
@@ -119,7 +120,7 @@ class FalconMambaModelTester:
return ( return (
config, config,
input_ids, input_ids,
None, attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
@@ -153,6 +154,7 @@ class FalconMambaModelTester:
( (
config, config,
input_ids, input_ids,
attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
@@ -161,6 +163,7 @@ class FalconMambaModelTester:
return ( return (
config, config,
input_ids, input_ids,
attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
@@ -253,12 +256,12 @@ class FalconMambaModelTester:
( (
config, config,
input_ids, input_ids,
_, attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
) = self.prepare_config_and_inputs() ) = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids} inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict return config, inputs_dict
@@ -491,3 +494,33 @@ class FalconMambaIntegrationTests(unittest.TestCase):
self.tokenizer.batch_decode(out, skip_special_tokens=False)[0], self.tokenizer.batch_decode(out, skip_special_tokens=False)[0],
"Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep", "Hello today I am going to show you how to make a simple and easy to make paper plane.\nStep",
) )
def test_batched_generation(self):
model_id = "tiiuae/falcon-mamba-7b"
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token_id = tok.eos_token_id
texts = ["Hello today", "Hello my name is Younes and today"]
EXPECTED_OUTPUT = [
"Hello today I'm going to show you how to make a 3D model of a house.\n",
"Hello my name is Younes and today I will be talking about the topic of “The importance of the internet in our life”.\n",
]
inputs = tok(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(torch_device)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=0, torch_dtype=torch.bfloat16)
out = model.generate(**inputs, max_new_tokens=20)
out = tok.batch_decode(out, skip_special_tokens=True)
self.assertListEqual(out, EXPECTED_OUTPUT)
# We test the same generations with inputs_embeds
with torch.no_grad():
inputs_embeds = model.get_input_embeddings()(inputs.pop("input_ids"))
inputs["inputs_embeds"] = inputs_embeds
out = model.generate(**inputs, max_new_tokens=20)
out = tok.batch_decode(out, skip_special_tokens=True)
self.assertListEqual(out, EXPECTED_OUTPUT)

View File

@@ -94,6 +94,7 @@ class MambaModelTester:
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
): ):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
attention_mask = ids_tensor([self.batch_size, self.seq_length], 1)
sequence_labels = None sequence_labels = None
token_labels = None token_labels = None
@@ -112,7 +113,7 @@ class MambaModelTester:
return ( return (
config, config,
input_ids, input_ids,
None, attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
@@ -146,6 +147,7 @@ class MambaModelTester:
( (
config, config,
input_ids, input_ids,
attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
@@ -154,6 +156,7 @@ class MambaModelTester:
return ( return (
config, config,
input_ids, input_ids,
attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
@@ -246,12 +249,12 @@ class MambaModelTester:
( (
config, config,
input_ids, input_ids,
_, attention_mask,
sequence_labels, sequence_labels,
token_labels, token_labels,
choice_labels, choice_labels,
) = self.prepare_config_and_inputs() ) = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids} inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict return config, inputs_dict