From 0cf27916f09a1a99af55ef4f2f3e8675372f38b6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 4 Jul 2025 09:01:56 +0200 Subject: [PATCH] Add packed tensor format support for flex/sdpa/eager through the mask! (#39194) * Add the necesary logic to mask_utils * add it everywhere * Update masking_utils.py * style * Update masking_utils.py * Update modeling_mimi.py * Update masking_utils.py * add support for more than batch size 1 * Update masking_utils.py * add test * style * Update test_masking_utils.py * Update masking_utils.py * add require_token * fix tests * fix --- src/transformers/generation/utils.py | 2 + src/transformers/masking_utils.py | 104 ++++++++++++-- .../models/arcee/modeling_arcee.py | 1 + src/transformers/models/aria/modeling_aria.py | 1 + .../models/bitnet/modeling_bitnet.py | 1 + .../models/cohere/modeling_cohere.py | 1 + .../models/cohere2/modeling_cohere2.py | 1 + .../models/cohere2/modular_cohere2.py | 1 + src/transformers/models/csm/modeling_csm.py | 2 + src/transformers/models/csm/modular_csm.py | 1 + .../deepseek_v3/modeling_deepseek_v3.py | 1 + src/transformers/models/dia/modeling_dia.py | 1 + src/transformers/models/dia/modular_dia.py | 1 + .../models/diffllama/modeling_diffllama.py | 1 + .../models/dots1/modeling_dots1.py | 1 + src/transformers/models/emu3/modeling_emu3.py | 1 + .../models/gemma/modeling_gemma.py | 1 + .../models/gemma/modular_gemma.py | 1 + .../models/gemma2/modeling_gemma2.py | 1 + .../models/gemma2/modular_gemma2.py | 1 + .../models/gemma3/modeling_gemma3.py | 4 + .../models/gemma3/modular_gemma3.py | 4 + .../models/gemma3n/modeling_gemma3n.py | 1 + .../models/gemma3n/modular_gemma3n.py | 1 + src/transformers/models/glm/modeling_glm.py | 1 + src/transformers/models/glm4/modeling_glm4.py | 1 + .../models/glm4v/modeling_glm4v.py | 1 + .../models/glm4v/modular_glm4v.py | 1 + .../models/gpt_neox/modeling_gpt_neox.py | 1 + .../models/gpt_neox/modular_gpt_neox.py | 1 + .../models/granite/modeling_granite.py | 1 + .../models/granite/modular_granite.py | 1 + .../models/helium/modeling_helium.py | 1 + .../models/llama/modeling_llama.py | 1 + .../models/llama4/modeling_llama4.py | 1 + src/transformers/models/mimi/modeling_mimi.py | 1 + .../models/minimax/modeling_minimax.py | 1 + .../models/minimax/modular_minimax.py | 1 + .../models/mistral/modeling_mistral.py | 1 + .../models/mistral/modular_mistral.py | 1 + .../models/mixtral/modeling_mixtral.py | 1 + .../models/mixtral/modular_mixtral.py | 1 + .../models/moonshine/modeling_moonshine.py | 1 + .../models/moonshine/modular_moonshine.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + .../models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi/modular_phi.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + .../modeling_phi4_multimodal.py | 1 + .../modular_phi4_multimodal.py | 1 + .../models/qwen2/modeling_qwen2.py | 1 + .../models/qwen2/modular_qwen2.py | 1 + .../qwen2_5_omni/modeling_qwen2_5_omni.py | 2 + .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 + .../models/qwen2_vl/modeling_qwen2_vl.py | 1 + .../models/qwen3/modeling_qwen3.py | 1 + .../models/qwen3_moe/modeling_qwen3_moe.py | 1 + .../models/smollm3/modeling_smollm3.py | 1 + .../models/starcoder2/modeling_starcoder2.py | 1 + .../models/starcoder2/modular_starcoder2.py | 1 + .../models/t5gemma/modeling_t5gemma.py | 3 + .../models/t5gemma/modular_t5gemma.py | 3 + .../models/whisper/modeling_whisper.py | 1 + tests/utils/test_masking_utils.py | 132 ++++++++++++++++++ 65 files changed, 303 insertions(+), 9 deletions(-) create mode 100644 tests/utils/test_masking_utils.py diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a03693922f..e36417269d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -656,6 +656,7 @@ class GenerationMixin(ContinuousMixin): # If it's not defined, it means the model uses the new general mask API if causal_mask_creation_function is None: # can't be found token_type_ids = getattr(model_input, "token_type_ids", None) + position_ids = getattr(model_input, position_ids_key, None) # Some models may overwrite the general one causal_mask_creation_function = getattr(self, "create_masks_for_generate", create_masks_for_generate) attention_mask = causal_mask_creation_function( @@ -665,6 +666,7 @@ class GenerationMixin(ContinuousMixin): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, token_type_ids=token_type_ids, ) else: diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 128abd56ff..8d5aab9f13 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -112,6 +112,10 @@ def chunked_causal_mask_function(chunk_size: int) -> Callable: def padding_mask_function(padding_mask: torch.Tensor) -> Callable: + """ + This return the mask_function function corresponding to a 2D padding mask. + """ + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: # Note that here the mask should ALWAYS be at least of the max `kv_index` size in the dimension 1. This is because # we cannot pad it here in the mask_function as we don't know the final size, and we cannot try/except, as it is not @@ -121,6 +125,17 @@ def padding_mask_function(padding_mask: torch.Tensor) -> Callable: return inner_mask +def packed_sequence_mask_function(packed_sequence_mask: torch.Tensor) -> Callable: + """ + This return the mask_function function corresponding to a 2D packed sequence mask. + """ + + def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: + return packed_sequence_mask[batch_idx, q_idx] == packed_sequence_mask[batch_idx, kv_idx] + + return inner_mask + + def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offset: int) -> Callable: """ This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths, @@ -592,12 +607,40 @@ class AttentionMaskInterface(GeneralInterface): ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface() +def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]: + """ + Find the indices of the sequence to which each new query token in the sequence belongs when using packed + tensor format (i.e. several sequences packed in the same batch dimension). + + Args: + position_ids (`torch.Tensor`) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. + + Returns: + A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we + pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]]. + """ + # What separate different sequences is when 2 consecutive positions_ids are separated by more than 1. So + # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result + # gives exactly the sequence indices + # Note that we assume that a single sequence cannot span several batch dimensions, i.e. 1 single sequence + # cannot be part of the end of the first batch dim and the start of the 2nd one for example + first_dummy_value = position_ids[:, :1] - 1 # We just need the diff on this first value to be 1 + position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1) + packed_sequence_mask = (position_diff != 1).cumsum(-1) + + # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0` + # but it causes issues with export + return packed_sequence_mask + + def _preprocess_mask_arguments( config: PretrainedConfig, input_embeds: torch.Tensor, attention_mask: Optional[Union[torch.Tensor, BlockMask]], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], layer_idx: Optional[int], ) -> tuple[bool, Optional[Union[torch.Tensor, BlockMask]], int, int]: """ @@ -617,6 +660,8 @@ def _preprocess_mask_arguments( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`torch.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. layer_idx (`int`, optional): If `past_key_values` is not None, this is the layer index of the cache from which to get the key-value length and offset. Indeed, for hybrid caches, different layers may return different lengths. @@ -626,6 +671,9 @@ def _preprocess_mask_arguments( Whether we should early exit mask creation, and return the mask as-is. attention_mask (`torch.Tensor` or `BlockMask` or `None`): The attention mask to either return immediately, or to use in downstream mask creation. + packed_sequence_mask (`torch.Tensor`, optional): + In case we detected packed sequence format, this is a tensor where each similar integer indicates that + the tokens belong to the same sequence. kv_length (`int`): The size that the key and value states will have during the attention computation. kv_offset (`int`): @@ -633,7 +681,7 @@ def _preprocess_mask_arguments( """ # If the mask is already 4D, simply return as-is (it was already prepared, or it is custom) if isinstance(attention_mask, (torch.Tensor, BlockMask)) and len(attention_mask.shape) == 4: - return True, attention_mask, None, None + return True, attention_mask, None, None, None # For TGI/vLLM backends, or other custom attention without equivalent mask creation: we don't need a mask! # Note: it's not ideal to check the `_global_mapping` attribute instead of the object itself, however otherwise @@ -641,7 +689,7 @@ def _preprocess_mask_arguments( # with `torch._dynamo.exc.Unsupported: 'inline in skipfiles:Mapping.__contains__ | __contains__, skipped # according trace_rules.lookup SKIP_DIRS'` -- can be removed when we require Python>=3.11 if config._attn_implementation not in ALL_MASK_ATTENTION_FUNCTIONS._global_mapping: - return True, None, None, None + return True, None, None, None, None # Move the mask to correct device, and potentially switch dtype for efficiency if attention_mask is not None and attention_mask.ndim == 2: @@ -654,7 +702,17 @@ def _preprocess_mask_arguments( else: kv_length, kv_offset = input_embeds.shape[1], 0 - return False, attention_mask, kv_length, kv_offset + # We check the position_ids for potential packed sequence format (only if the 2D attention mask is explicitly None, + # and we don't have past_key_values, i.e. generally a training setup) + packed_sequence_mask = None + if position_ids is not None and attention_mask is None and past_key_values is None: + batch_size = input_embeds.shape[0] + # The position ids are sometimes just unsqueezed, without being expanded + if batch_size != position_ids.shape[0]: + position_ids = position_ids.expand(batch_size, -1) + packed_sequence_mask = find_packed_sequence_indices(position_ids) + + return False, attention_mask, packed_sequence_mask, kv_length, kv_offset def create_causal_mask( @@ -663,6 +721,7 @@ def create_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, BlockMask]]: @@ -684,6 +743,8 @@ def create_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`torch.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. or_mask_function (`Callable`, optional): An optional mask function to combine with the causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. @@ -697,8 +758,8 @@ def create_causal_mask( else: layer_idx = 0 - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx ) if early_exit: return attention_mask @@ -711,6 +772,11 @@ def create_causal_mask( # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + # If we detected packing format + if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) + allow_is_causal_skip = False + # Allow slight deviations from causal mask if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: @@ -744,6 +810,7 @@ def create_sliding_window_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, BlockMask]]: @@ -766,6 +833,8 @@ def create_sliding_window_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`torch.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. or_mask_function (`Callable`, optional): An optional mask function to combine with the sliding causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the sliding causal one, for example for image tokens handling. @@ -779,8 +848,8 @@ def create_sliding_window_causal_mask( else: layer_idx = 0 - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx ) if early_exit: return attention_mask @@ -797,6 +866,11 @@ def create_sliding_window_causal_mask( # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + # If we detected packing format + if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) + allow_is_causal_skip = False + # Allow slight deviations from sliding causal mask if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: @@ -831,6 +905,7 @@ def create_chunked_causal_mask( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, ) -> Optional[Union[torch.Tensor, BlockMask]]: @@ -853,6 +928,8 @@ def create_chunked_causal_mask( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`torch.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. or_mask_function (`Callable`, optional): An optional mask function to combine with the chunked causal mask function (by doing the union of both). This is useful to easily overlay another mask on top of the chunked causal one, for example for image tokens handling. @@ -866,8 +943,8 @@ def create_chunked_causal_mask( else: layer_idx = 0 - early_exit, attention_mask, kv_length, kv_offset = _preprocess_mask_arguments( - config, input_embeds, attention_mask, cache_position, past_key_values, layer_idx + early_exit, attention_mask, packed_sequence_mask, kv_length, kv_offset = _preprocess_mask_arguments( + config, input_embeds, attention_mask, cache_position, past_key_values, position_ids, layer_idx ) if early_exit: return attention_mask @@ -891,6 +968,11 @@ def create_chunked_causal_mask( # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True + # If we detected packing format + if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) + allow_is_causal_skip = False + # Allow slight deviations from chunked causal mask if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: @@ -932,6 +1014,7 @@ def create_masks_for_generate( attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], or_mask_function: Optional[Callable] = None, and_mask_function: Optional[Callable] = None, **kwargs, @@ -953,6 +1036,8 @@ def create_masks_for_generate( A tensor of shape (query_length,) indicating the current indices of the input sequence elements. past_key_values (`Cache`, optional): The past key values, if we use a cache. + position_ids (`torch.Tensor`, optional) + A 2D tensor of shape (batch_size, query_length) indicating the positions of each token in the sequences. or_mask_function (`Callable`, optional): An optional mask function to combine with the other mask function (by doing the union of both). This is useful to easily overlay another mask on top of the causal one, for example for image tokens handling. @@ -969,6 +1054,7 @@ def create_masks_for_generate( "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, "or_mask_function": or_mask_function, "and_mask_function": and_mask_function, } diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index da23391812..b1b58667a0 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -423,6 +423,7 @@ class ArceeModel(ArceePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 00b912c5b2..af2b88ca72 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -806,6 +806,7 @@ class AriaTextModel(AriaTextPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index afafd3f911..48a804b0a7 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -420,6 +420,7 @@ class BitNetModel(BitNetPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index ad1604bed4..19a140ae81 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -457,6 +457,7 @@ class CohereModel(CoherePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 3fec29e976..afcaee5c2f 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -434,6 +434,7 @@ class Cohere2Model(Cohere2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index b32e4d94dd..fc4f24b834 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -455,6 +455,7 @@ class Cohere2Model(Gemma2Model): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 3b96c272c0..de76c03616 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -500,6 +500,7 @@ class CsmDepthDecoderModel(CsmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds @@ -811,6 +812,7 @@ class CsmBackboneModel(CsmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index c1a930aa70..1f6627bef5 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -238,6 +238,7 @@ class CsmDepthDecoderModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 541ae6669e..4287e44a7f 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -610,6 +610,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index b87712852f..9e317e029c 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -629,6 +629,7 @@ class DiaDecoder(DiaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, diff --git a/src/transformers/models/dia/modular_dia.py b/src/transformers/models/dia/modular_dia.py index fe437fde84..2225788f61 100644 --- a/src/transformers/models/dia/modular_dia.py +++ b/src/transformers/models/dia/modular_dia.py @@ -455,6 +455,7 @@ class DiaDecoder(DiaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 383c329c99..06ec77b1bf 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -697,6 +697,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 58b805cca6..e0c2f5ce51 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -532,6 +532,7 @@ class Dots1Model(Dots1PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 422838603c..6d3ab0402f 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1264,6 +1264,7 @@ class Emu3TextModel(Emu3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 399c809d12..906b29ea0d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -416,6 +416,7 @@ class GemmaModel(GemmaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) # embed positions diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index f20f2b7c99..d2361ab114 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -418,6 +418,7 @@ class GemmaModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) # embed positions diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index bfd3317946..15a502b264 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -440,6 +440,7 @@ class Gemma2Model(Gemma2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index b317936c77..c5f8d63975 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -422,6 +422,7 @@ class Gemma2Model(GemmaModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 2d86e9f04c..eddea94b91 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -540,6 +540,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { @@ -923,6 +924,7 @@ class Gemma3Model(Gemma3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` @@ -1182,6 +1184,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1192,6 +1195,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 0b0960a6a9..1fa0ee273a 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -607,6 +607,7 @@ class Gemma3TextModel(Gemma2Model): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { @@ -825,6 +826,7 @@ class Gemma3Model(PaliGemmaModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } if token_type_ids is not None and inputs_embeds.shape[1] != 1: # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` @@ -1034,6 +1036,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): attention_mask: Optional[torch.Tensor], cache_position: torch.Tensor, past_key_values: Optional[Cache], + position_ids: Optional[torch.Tensor], token_type_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: @@ -1044,6 +1047,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Add the token type ids mask for generate as well if token_type_ids is not None and input_embeds.shape[1] != 1: diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3a4995610d..9ba504a5bd 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -1649,6 +1649,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index b0a5099ff5..e65a7696ec 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -2088,6 +2088,7 @@ class Gemma3nTextModel(Gemma3TextModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 17deed6bc7..47d42d58e4 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -437,6 +437,7 @@ class GlmModel(GlmPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index ddb1592388..00c0f9ab59 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -445,6 +445,7 @@ class Glm4Model(Glm4PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index a97d9c0f82..1c2be3fdcd 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -902,6 +902,7 @@ class Glm4vTextModel(Glm4vPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 5732503daa..2d296d53a9 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -956,6 +956,7 @@ class Glm4vTextModel(Qwen2_5_VLTextModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 2e563e401f..1493c1cc11 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -392,6 +392,7 @@ class GPTNeoXModel(GPTNeoXPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) # Prepare head mask if needed diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index fde2677b4e..03c8300ed0 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -338,6 +338,7 @@ class GPTNeoXModel(LlamaModel, nn.Module): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) # Prepare head mask if needed diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index d5e7e4fe51..37ede89bd4 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -440,6 +440,7 @@ class GraniteModel(GranitePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 0432b9e2a5..c91cd4b12b 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -181,6 +181,7 @@ class GraniteModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index fb8fc6f3be..cb9a4b268e 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -422,6 +422,7 @@ class HeliumModel(HeliumPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 74c9651aca..78ceb22ee6 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -422,6 +422,7 @@ class LlamaModel(LlamaPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 259482934f..0e52700c2f 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -552,6 +552,7 @@ class Llama4TextModel(Llama4PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 45c6ed1081..abfb4892d7 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1120,6 +1120,7 @@ class MimiTransformerModel(nn.Module): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) # decoder layers diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 66ed4adcea..190fc8e529 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -731,6 +731,7 @@ class MiniMaxModel(MiniMaxPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 9b6fc12ae3..ed88d10bb9 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -536,6 +536,7 @@ class MiniMaxModel(MixtralModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4b222eabe2..1bf4ea1f1e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -399,6 +399,7 @@ class MistralModel(MistralPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index d78cc22ca2..2cd2be1eaa 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -160,6 +160,7 @@ class MistralModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 526bf2bbd7..10520f6d4c 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -520,6 +520,7 @@ class MixtralModel(MixtralPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index cd774a5597..cc9bfb5297 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -370,6 +370,7 @@ class MixtralModel(MistralModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 2909fb386f..4f33ee6e2b 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -743,6 +743,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 500231f3b4..99ebb09c72 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -749,6 +749,7 @@ class MoonshineDecoder(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index fc6a718862..1e2c9f6bbc 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -401,6 +401,7 @@ class OlmoModel(OlmoPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 84f5e5ad4e..97559927de 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -407,6 +407,7 @@ class Olmo2Model(Olmo2PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 1c51360440..f040199742 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -395,6 +395,7 @@ class PhiModel(PhiPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 46a367bbdb..93690075ae 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -245,6 +245,7 @@ class PhiModel(LlamaModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 54fd3d1caf..7113ef56f7 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -454,6 +454,7 @@ class Phi3Model(Phi3PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 27c199bf50..ce49f8f901 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -1762,6 +1762,7 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 28cd2de9fb..de85fc8727 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -1577,6 +1577,7 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 4ba0b43e13..01432c3ca9 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -406,6 +406,7 @@ class Qwen2Model(Qwen2PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index f73aedc2c0..114f6b7fea 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -167,6 +167,7 @@ class Qwen2Model(MistralModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 37a290e800..aab4261046 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1629,6 +1629,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { @@ -2189,6 +2190,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 1b69973cc3..90f99a49bd 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -893,6 +893,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 311622b222..5e77ffd329 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -868,6 +868,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index e64f966759..7fbb5a90d0 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -432,6 +432,7 @@ class Qwen3Model(Qwen3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 47ec0d10ab..5ba2bccd11 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -527,6 +527,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 30b566be3e..27f2b1aa41 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -436,6 +436,7 @@ class SmolLM3Model(SmolLM3PreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values, + "position_ids": position_ids, } # Create the masks causal_mask_mapping = { diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 1e1d9c6436..c6af5bfd94 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -400,6 +400,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 7226743713..5c63385958 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -214,6 +214,7 @@ class Starcoder2Model(MistralModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) hidden_states = inputs_embeds diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index feccf6d7d9..1eacec6a27 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -731,6 +731,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": None, + "position_ids": position_ids, } # Create the masks self_attn_mask_mapping = { @@ -874,6 +875,7 @@ class T5GemmaDecoder(T5GemmaEncoder): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, } # Create the masks self_attn_mask_mapping = { @@ -890,6 +892,7 @@ class T5GemmaDecoder(T5GemmaEncoder): "attention_mask": encoder_attention_mask, "cache_position": cache_position, "past_key_values": None, + "position_ids": None, } cross_attn_mask_mapping = { "full_attention": create_causal_mask( diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 970816fc38..01d8a401f4 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -678,6 +678,7 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": None, + "position_ids": position_ids, } # Create the masks self_attn_mask_mapping = { @@ -821,6 +822,7 @@ class T5GemmaDecoder(T5GemmaEncoder): "attention_mask": attention_mask, "cache_position": cache_position, "past_key_values": past_key_values.self_attention_cache if past_key_values is not None else None, + "position_ids": position_ids, } # Create the masks self_attn_mask_mapping = { @@ -837,6 +839,7 @@ class T5GemmaDecoder(T5GemmaEncoder): "attention_mask": encoder_attention_mask, "cache_position": cache_position, "past_key_values": None, + "position_ids": None, } cross_attn_mask_mapping = { "full_attention": create_causal_mask( diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d3e9c8e03a..43f1eccdc0 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -927,6 +927,7 @@ class WhisperDecoder(WhisperPreTrainedModel): attention_mask=attention_mask, cache_position=cache_position, past_key_values=past_key_values, + position_ids=position_ids, ) if self.gradient_checkpointing and self.training: diff --git a/tests/utils/test_masking_utils.py b/tests/utils/test_masking_utils.py new file mode 100644 index 0000000000..3b162e0b08 --- /dev/null +++ b/tests/utils/test_masking_utils.py @@ -0,0 +1,132 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers.testing_utils import is_torch_available, require_torch + + +if is_torch_available(): + import torch + from torch.nn.attention.flex_attention import create_block_mask + + from transformers import LlamaConfig + from transformers.masking_utils import create_causal_mask + + +# fmt: off +EXPECTED_PACKED_MASK = torch.tensor([[[ + [ True, False, False, False, False, False, False, False, False, False], + [ True, True, False, False, False, False, False, False, False, False], + [ True, True, True, False, False, False, False, False, False, False], + [ True, True, True, True, False, False, False, False, False, False], + [False, False, False, False, True, False, False, False, False, False], + [False, False, False, False, True, True, False, False, False, False], + [False, False, False, False, False, False, True, False, False, False], + [False, False, False, False, False, False, True, True, False, False], + [False, False, False, False, False, False, True, True, True, False], + [False, False, False, False, False, False, True, True, True, True]]], + + + [[[ True, False, False, False, False, False, False, False, False, False], + [ True, True, False, False, False, False, False, False, False, False], + [ True, True, True, False, False, False, False, False, False, False], + [ True, True, True, True, False, False, False, False, False, False], + [ True, True, True, True, True, False, False, False, False, False], + [ True, True, True, True, True, True, False, False, False, False], + [False, False, False, False, False, False, True, False, False, False], + [False, False, False, False, False, False, True, True, False, False], + [False, False, False, False, False, False, True, True, True, False], + [False, False, False, False, False, False, True, True, True, True] +]]], dtype=torch.bool) +# fmt: on + + +@require_torch +class MaskTest(unittest.TestCase): + def test_packed_sequence_mask_sdpa(self): + config = LlamaConfig() + config._attn_implementation = "sdpa" + + batch_size = 2 + sequence_length = 10 + cache_position = torch.arange(sequence_length) + + # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens + position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) + + causal_mask = create_causal_mask( + config=config, + # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings + input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), + attention_mask=None, + cache_position=cache_position, + past_key_values=None, + position_ids=position_ids, + ) + + self.assertTrue((causal_mask == EXPECTED_PACKED_MASK).all()) + + def test_packed_sequence_mask_eager(self): + config = LlamaConfig() + config._attn_implementation = "eager" + + batch_size = 2 + sequence_length = 10 + cache_position = torch.arange(sequence_length) + + # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens + position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) + + causal_mask = create_causal_mask( + config=config, + # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings + input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), + attention_mask=None, + cache_position=cache_position, + past_key_values=None, + position_ids=position_ids, + ) + + min_dtype = torch.finfo(torch.float16).min + self.assertTrue((causal_mask == torch.where(EXPECTED_PACKED_MASK, 0.0, min_dtype)).all()) + + def test_packed_sequence_mask_flex_attention(self): + config = LlamaConfig() + config._attn_implementation = "flex_attention" + + batch_size = 2 + sequence_length = 10 + cache_position = torch.arange(sequence_length) + + # First batch has 3 packed sequences of 4, 2 and 4 tokens respectively, second has 2 of 6 and 4 tokens + position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 0, 1, 2, 3]]) + + causal_mask = create_causal_mask( + config=config, + # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings + input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16), + attention_mask=None, + cache_position=cache_position, + past_key_values=None, + position_ids=position_ids, + ) + + def dummy_mask_mod(b, h, q, kv): + return EXPECTED_PACKED_MASK[b, h, q, kv] + + EXPECTED_BLOCK_MASK = create_block_mask(dummy_mask_mod, 2, None, 10, 10, device="cpu") + + # We compatre the str representations, as the BlockMask objects themselves cannot easily be compared + self.assertEqual(causal_mask.to_string(), EXPECTED_BLOCK_MASK.to_string())