From adea67541ad31f5193e56d2b18d2e7c57c4ecc15 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 10 Oct 2024 11:50:39 +0200 Subject: [PATCH] Phi3: fix attn for sliding window (#33586) * fix phi3 attn fir sliding window * fix tests * address most comment * style * update after rebase * add more models * fix tests --- src/transformers/models/mimi/modeling_mimi.py | 50 +++-- .../models/mistral/modeling_mistral.py | 180 +++++++++++++++--- .../models/mixtral/modeling_mixtral.py | 120 ++++++++++-- src/transformers/models/phi3/modeling_phi3.py | 51 +++-- .../models/phimoe/modeling_phimoe.py | 52 +++-- .../models/qwen2/modeling_qwen2.py | 123 ++++++++++-- .../models/qwen2_moe/modeling_qwen2_moe.py | 123 ++++++++++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 52 +++-- .../models/starcoder2/modeling_starcoder2.py | 123 ++++++++++-- tests/models/phi3/test_modeling_phi3.py | 90 +++++++++ tests/models/qwen2/test_modeling_qwen2.py | 2 +- .../qwen2_moe/test_modeling_qwen2_moe.py | 2 +- utils/modular_model_converter.py | 8 +- 13 files changed, 831 insertions(+), 145 deletions(-) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index b82c90cf2b..514f9de706 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -23,7 +23,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -1027,7 +1027,7 @@ class MimiTransformerModel(nn.Module): attentions=all_self_attns, ) - # Copied from transformers.models.gemma.modeling_gemma.GemmaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1046,21 +1046,30 @@ class MimiTransformerModel(nn.Module): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1077,6 +1086,8 @@ class MimiTransformerModel(nn.Module): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1088,13 +1099,12 @@ class MimiTransformerModel(nn.Module): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mimi def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1103,6 +1113,8 @@ class MimiTransformerModel(nn.Module): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: MimiConfig, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1110,13 +1122,11 @@ class MimiTransformerModel(nn.Module): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1125,6 +1135,10 @@ class MimiTransformerModel(nn.Module): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`MimiConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1134,19 +1148,27 @@ class MimiTransformerModel(nn.Module): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 70d97cb4fb..b0ffe3e56e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -856,7 +856,7 @@ class MistralModel(MistralPreTrainedModel): use_cache: bool, output_attentions: bool, ): - if self._attn_implementation == "flash_attention_2": + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: @@ -872,12 +872,11 @@ class MistralModel(MistralPreTrainedModel): # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - - # cache_position must be valid here no matter which cache we use past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward if ( self.config._attn_implementation == "sdpa" and not (using_static_cache or using_sliding_window_cache) @@ -906,30 +905,18 @@ class MistralModel(MistralPreTrainedModel): else past_seen_tokens + sequence_length + 1 ) - if attention_mask is not None and attention_mask.dim() == 4: - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if self.config.sliding_window is not None: - if not using_sliding_window_cache or sequence_length > self.config.sliding_window: - exclude_mask.bitwise_or_( - torch.arange(target_length, device=device) - <= (cache_position.reshape(-1, 1) - self.config.sliding_window) - ) - causal_mask *= exclude_mask - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.dim() == 2: - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) if ( self.config._attn_implementation == "sdpa" @@ -944,6 +931,73 @@ class MistralModel(MistralPreTrainedModel): return causal_mask + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MistralConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`MistralConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] @@ -1074,6 +1128,78 @@ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # `contiguous()` needed for compilation use cases + model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + @add_start_docstrings( """ diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index b7c781e80f..9c7fadbb8f 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -29,7 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1068,7 +1068,7 @@ class MixtralModel(MixtralPreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1087,21 +1087,30 @@ class MixtralModel(MixtralPreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1118,6 +1127,8 @@ class MixtralModel(MixtralPreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1129,13 +1140,12 @@ class MixtralModel(MixtralPreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mixtral def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1144,6 +1154,8 @@ class MixtralModel(MixtralPreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: MixtralConfig, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1151,13 +1163,11 @@ class MixtralModel(MixtralPreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1166,6 +1176,10 @@ class MixtralModel(MixtralPreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`MixtralConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1175,19 +1189,27 @@ class MixtralModel(MixtralPreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1344,6 +1366,76 @@ class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): router_logits=outputs.router_logits, ) + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + output_router_logits=False, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + @add_start_docstrings( """ diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 811b584e50..0380c6cd49 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -25,7 +25,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1036,7 +1036,6 @@ class Phi3Model(Phi3PreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1055,21 +1054,30 @@ class Phi3Model(Phi3PreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1086,6 +1094,8 @@ class Phi3Model(Phi3PreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1097,13 +1107,12 @@ class Phi3Model(Phi3PreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3 def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1112,6 +1121,8 @@ class Phi3Model(Phi3PreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: Phi3Config, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1119,13 +1130,11 @@ class Phi3Model(Phi3PreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1134,6 +1143,10 @@ class Phi3Model(Phi3PreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`Phi3Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1143,19 +1156,27 @@ class Phi3Model(Phi3PreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1372,6 +1393,8 @@ class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): device=device, cache_position=cache_position, batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, ) if num_logits_to_keep is not None: diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 07fba62722..d1705f04dd 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -24,7 +24,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1203,7 +1203,7 @@ class PhimoeModel(PhimoePreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1222,21 +1222,30 @@ class PhimoeModel(PhimoePreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1253,6 +1262,8 @@ class PhimoeModel(PhimoePreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1264,13 +1275,12 @@ class PhimoeModel(PhimoePreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phimoe def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1279,6 +1289,8 @@ class PhimoeModel(PhimoePreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: PhimoeConfig, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1286,13 +1298,11 @@ class PhimoeModel(PhimoePreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1301,6 +1311,10 @@ class PhimoeModel(PhimoePreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`PhimoeConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1310,19 +1324,27 @@ class PhimoeModel(PhimoePreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1561,6 +1583,8 @@ class PhimoeForCausalLM(PhimoePreTrainedModel, GenerationMixin): device=device, cache_position=cache_position, batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, ) if num_logits_to_keep is not None: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8f1226cdd8..50f273ba76 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -28,7 +28,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -954,7 +954,7 @@ class Qwen2Model(Qwen2PreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -973,21 +973,30 @@ class Qwen2Model(Qwen2PreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1004,6 +1013,8 @@ class Qwen2Model(Qwen2PreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1015,13 +1026,12 @@ class Qwen2Model(Qwen2PreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2 def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1030,6 +1040,8 @@ class Qwen2Model(Qwen2PreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: Qwen2Config, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1037,13 +1049,11 @@ class Qwen2Model(Qwen2PreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1052,6 +1062,10 @@ class Qwen2Model(Qwen2PreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`Qwen2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1061,19 +1075,27 @@ class Qwen2Model(Qwen2PreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1206,6 +1228,79 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) + # Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # `contiguous()` needed for compilation use cases + model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + @add_start_docstrings( """ diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 4e2605a6a5..2ab13b7227 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -29,7 +29,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1135,7 +1135,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): router_logits=all_router_logits, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1154,21 +1154,30 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1185,6 +1194,8 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1196,13 +1207,12 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2Moe def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1211,6 +1221,8 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: Qwen2MoeConfig, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1218,13 +1230,11 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1233,6 +1243,10 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`Qwen2MoeConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1242,19 +1256,27 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1410,6 +1432,79 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): router_logits=outputs.router_logits, ) + # Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # `contiguous()` needed for compilation use cases + model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + @add_start_docstrings( """ diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index d26e32332e..283e38d3a7 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -30,7 +30,7 @@ import torch.utils.checkpoint from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1217,7 +1217,7 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1236,21 +1236,30 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -1267,6 +1276,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -1278,13 +1289,12 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2VL def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1293,6 +1303,8 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: Qwen2VLConfig, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1300,13 +1312,11 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1315,6 +1325,10 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`Qwen2VLConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1324,19 +1338,27 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1820,6 +1842,8 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): device=device, cache_position=cache_position, batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, ) model_inputs.update( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index f7f5d7d188..e0fdbef1a3 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -28,7 +28,7 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -928,7 +928,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel): attentions=all_self_attns, ) - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -947,21 +947,30 @@ class Starcoder2Model(Starcoder2PreTrainedModel): # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -978,6 +987,8 @@ class Starcoder2Model(Starcoder2PreTrainedModel): device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -989,13 +1000,12 @@ class Starcoder2Model(Starcoder2PreTrainedModel): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Starcoder2 def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1004,6 +1014,8 @@ class Starcoder2Model(Starcoder2PreTrainedModel): device: torch.device, cache_position: torch.Tensor, batch_size: int, + config: Starcoder2Config, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -1011,13 +1023,11 @@ class Starcoder2Model(Starcoder2PreTrainedModel): Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1026,6 +1036,10 @@ class Starcoder2Model(Starcoder2PreTrainedModel): Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`Starcoder2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -1035,19 +1049,27 @@ class Starcoder2Model(Starcoder2PreTrainedModel): causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask |= sliding_attend_mask + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask @@ -1182,6 +1204,79 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): attentions=outputs.attentions, ) + # Copied from transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # `contiguous()` needed for compilation use cases + model_inputs = {"input_ids": input_ids.contiguous(), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_cache_shape(), + dtype=self.lm_head.weight.dtype, + device=device, + cache_position=cache_position, + batch_size=batch_size, + config=self.config, + past_key_values=past_key_values, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + @add_start_docstrings( """ diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index 2ecf0bb3be..2c5557dfd6 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -615,3 +615,93 @@ class Phi3IntegrationTest(unittest.TestCase): ] self.assertListEqual(output_text, EXPECTED_OUTPUT) + + def test_phi3_mini_4k_sliding_window(self): + """ + This tests that Phi3 doesn't deteriorate in quality for long context generations. Since Phi3 has + sliding window attention, the test is tailored so that (context + max_new_tokens > sliding_window). + See #33586 for more + """ + model = Phi3ForCausalLM.from_pretrained( + "microsoft/Phi-3-mini-4k-instruct", device_map=torch_device, torch_dtype=torch.bfloat16 + ) + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + input_text = """ + <|user|> + Tell me about Paris, France.<|end|> + <|assistant|> + Paris, the capital city of France, is renowned for its rich history, iconic landmarks, and vibrant culture. Known as "The City of Light," Paris is situated in the north-central part of the country along the Seine River. + + Here are some key aspects of Paris: + + 1. Landmarks: Paris is home to numerous famous landmarks, including the Eiffel Tower, the Louvre Museum, Notre-Dame Cathedral, and the Champs-Élysées. The Eiffel Tower, built in 1889, is an iconic symbol of Paris and attracts millions of tourists each year. The Louvre Museum, the world's largest art museum, houses thousands of works of art, including the Mona Lisa and the Venus de Milo. + + 2. History: Paris has a rich history dating back to the 3rd century BC, when it was founded by a Celtic tribe called the Parisii. Over the centuries, the city has been influenced by various cultures, including the Romans, the Franks, and the Normans. The French Revolution in the late 18th century marked a significant turning point in Paris's history, leading to the establishment of the modern French Republic. + + 3. Culture: Paris is a global center for art, fashion, gastronomy, and culture. The city is home to numerous museums, including the Centre Pompidou, Musée d'Orsay, and Musée Rodin. Paris is also known for its fashion industry, with many famous designers having their origins in the city. The city's cuisine is also highly regarded, with a focus on fresh ingredients, and a wide variety of dishes, including French classics like coq au vin, boeuf bourguignon, and crêpes. + + 4. Architecture: Parisian architecture is characterized by its diverse styles, ranging from Gothic and Romanesque to Art Nouveau and Art Deco. The city's famous Haussmannian buildings, designed by Baron Haussmann in the mid-19th century, are known for their uniform facades, wrought-iron balconies, and large windows. + + 5. Transportation: Paris has an extensive public transportation system, including the Paris Métro, RER (suburban trains), and buses. The city's iconic yellow taxis are also a popular mode of transportation. + + 6. Language: The official language of Paris is French, and the city's residents are known for their charm and politeness. + + 7. Festivals and Events: Paris hosts numerous festivals and events throughout the year, including the annual Bastille Day celebrations, the Paris Fashion Week, and the famous annual New Year's Eve fireworks on the Eiffel Tower. + + 8. Geography: Paris is located in the north-central part of France, with the Seine River running through the city. The city's geography is characterized by rolling hills and picturesque parks, such as the Bois de Boulogne and the Jardin des Tuileries. + + 9. Population: As of 2021, Paris has an estimated population of around 2.2 million residents, with the metropolitan area housing over 12 million people. + + In summary, Paris is a city steeped in history, culture, and art, with a unique blend of architectural styles and a vibrant atmosphere that continues to captivate millions of visitors each year.<|end|> + <|user|> + Please give me a list of 5 architectural landmarks in Paris, France.<|end|> + <|assistant|> + 1. Eiffel Tower: Designed by Gustave Eiffel and completed in 1889, the Eiffel Tower is an iconic symbol of Paris and France. Standing at 324 meters tall, it was the tallest man-made structure in the world until the completion of the Chrysler Building in New York in 1930. The Eiffel Tower is made of wrought iron and offers visitors stunning views of the city from its three levels. + + 2. Notre-Dame Cathedral: Located on the Île de la Cité, Notre-Dame Cathedral is a masterpiece of French Gothic architecture. Construction began in the 12th century and continued for over 200 years, with the cathedral's completion in the 14th century. The cathedral is famous for its intricate facade, stained-glass windows, and the iconic gargoyles and chimeras. + + 3. Louvre Museum: Originally built as a fortress in the 12th century, the Louvre Museum is now the world's largest art museum and a historic monument in Paris. The museum's most famous landmark is the iconic glass pyramid entrance, designed by architect I. M. Pei in the 1980s. The Louvre houses over 380,000 works of art, including the Mona Lisa and the Venus de Milo. + + 4. Sacré-Cœur Basilica: The Sacré-Cœur Basilica, also known as the Basilique du Sacré-Cœur, is a Roman Catholic church and minor basilica located at the summit of the butte Montmartre, the highest point in Paris. The basilica was designed by Paul Abadie and dedicated in 1914. Its white domes and lavender-colored travertine stone make it a distinctive landmark in the Paris skyline. + + 5. Arc de Triomphe: The Arc de Triomphe is a monumental structure located at the western end of the Champs-Élysées. Commissioned by Napoleon in 1806, the Arc was designed by Jean-François-Thérèse Chalgrin and completed in 1836. The monument honors those who fought and died for France during the French Revolutionary and Napoleonic Wars. The Arc features sculptural reliefs and inscriptions, and its façade is adorned with the names of 357 generals and 660 soldiers. + + These five architectural landmarks showcase the diverse styles and historical periods of Paris, from Gothic to Neoclassical, and from the 19th to the 20th centuries. Each landmark has its unique features and contributes to the city's rich architectural heritage.<|end|> + <|user|> + Please give me a list of 10 famous items displayed in the Louvre Museum. Thanks!<|end|> + <|assistant|> + 1. Mona Lisa: The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is arguably the most famous painting in the world. The portrait is known for its enigmatic smile and masterful use of sfumato, a technique that creates a soft, hazy effect. + + 2. Venus de Milo: This ancient Greek statue, believed to have been created around 130-100 BC, is a masterpiece of Hellenistic sculpture. The Venus de Milo is renowned for its graceful beauty and the mystery surrounding its missing arms. + + 3. Winged Victory of Samothrace: This Hellenistic sculpture, dating back to the 2nd century BC, depicts the Greek goddess Nike, the personification of victory. The sculpture is celebrated for its dynamic movement and intricate details. + + 4. Liberty Leading the People: This iconic painting by Eugène Delacroix, created in 1830, commemorates the July Revolution in France. The artwork depicts a woman personifying Liberty leading a group of revolutionaries over the bodies of the fallen. + + 5. The Wedding at Cana: A 1516 painting by Veronese, The Wedding at Cana is a large-scale work that depicts the biblical story of Jesus turning water into wine at a wedding feast. The painting is known for its vibrant colors and intricate details. + + 6. The Raft of the Medusa: This 1819 painting by Théodore Géricault is a powerful depiction of the aftermath of the shipwreck of the French frigate Méduse. The painting is famous for its dramatic composition and emotional intensity. + + 7. The Coronation of Napoleon: This 1805 painting by Jacques-Louis David portrays the coronation of Napoleon Bonaparte as Emperor of the French. The artwork is a masterpiece of neoclassical style and captures the pomp and ceremony of the event. + + 8. The Death of Socrates: A 1914 painting by Pablo Picasso, The Death of Socrates depicts the moment when the ancient Greek philosopher Socrates drinks the poison hemlock, as ordered by the Athenian government. The painting is a powerful expression of the existential themes associated with Socrates' death. + + 9. The Turkish Bath: A 1887 painting by Puvis de Chavannes, The Turkish Bath is a large-scale work that depicts a scene of relaxation and leisity in a traditional Turkish bath. The painting is known for its subdued color palette and serene atmosphere. + + 10. The Wedding at Cana: A 1901 painting by Giovanni Boldini, The Wedding at Cana is a lively depiction of a festive Italian wedding. The painting is characterized by Boldini's signature flamboyant style and his ability to capture the essence of the subjects with a sense of movement and energy. + + These 10 famous items in the Louvre Museum showcase a diverse range of artistic styles and periods, from ancient Greek and Roman sculptures to Renaissance and 19th-century paintings. The collection at the Louvre is one of the world's largest and most important museums, housing over 800,000 works of art, and these famous items are just a small selection of the museum's vast collection.<|end|> + <|user|> + Please give me a list of 10 famous French dishes.<|end|> + <|assistant|> + """ + + inputs = tokenizer(input_text, return_tensors="pt").to(device=torch_device) + outputs = model.generate(**inputs, max_new_tokens=100) + output_text = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) + EXPECTED_OUTPUT = [ + '1. Coq au Vin: Coq au Vin is a classic French dish that translates to "rooster in wine." The dish consists of chicken braised with wine, lardons, mushrooms, and garlic. It is a hearty and flavorful dish that is often served with potatoes or rice.\n\n 2. Boeuf Bourguignon: Boeuf Bourguignon is a traditional French beef stew that' + ] + + self.assertListEqual(output_text, EXPECTED_OUTPUT) diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 5bcf08258a..c7fe657798 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -64,7 +64,7 @@ class Qwen2ModelTester: num_hidden_layers=5, max_window_layers=3, use_sliding_window=True, - sliding_window=2, + sliding_window=50, num_attention_heads=4, num_key_value_heads=2, intermediate_size=37, diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index d2292f1ea6..11fc55f6ba 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -64,7 +64,7 @@ class Qwen2MoeModelTester: num_hidden_layers=5, max_window_layers=3, use_sliding_window=True, - sliding_window=2, + sliding_window=50, num_attention_heads=4, num_key_value_heads=2, intermediate_size=37, diff --git a/utils/modular_model_converter.py b/utils/modular_model_converter.py index 599dc70e17..0ae641d4e5 100644 --- a/utils/modular_model_converter.py +++ b/utils/modular_model_converter.py @@ -18,7 +18,7 @@ import importlib import os import re from collections import defaultdict, deque -from typing import Dict, List, Set +from typing import Dict, List, Optional, Set import libcst as cst from check_copies import run_ruff @@ -623,7 +623,7 @@ def get_new_part(class_name, base_class): return snake_case -def find_all_dependencies(function: str, dependency_mapping: dict[str, set]): +def find_all_dependencies(function: str, dependency_mapping: Dict[str, set]): """Return all the dependencies of the given top-level function. Given the following structure in the `modular_xxx.py` file: ``` def foo1(): @@ -1001,8 +1001,8 @@ class ModularConverterTransformer(CSTTransformer): top_level_function: str, body: dict, function_node: cst.FunctionDef, - matching_callers: set | None = None, - parent: str | None = None, + matching_callers: Optional[set] = None, + parent: Optional[str] = None, ) -> bool: """Check if the `top_level_function` should be added to the body (i.e. it is not already present, and `matching_callers` is not empy, or `parent`is provided). If it should be added, do it (in the correct location, just before its caller) and return