[qwen2 vl] fix packing with all attentions (#39447)

* fix qwen2 vl packing in FA2

* why? delete!

* qwen2-5-vl seems to work now

* update

* fix tests

* start by adapting FA2 tests

* add similar tests for sdpa/eager

* address comments

* why is this even in conditional model and not base model?
This commit is contained in:
Raushan Turganbay
2025-07-21 12:19:15 +02:00
committed by GitHub
parent e42681b48b
commit 344012b3a6
9 changed files with 478 additions and 100 deletions

View File

@@ -40,7 +40,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelO
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
from ...utils import auto_docstring, check_torch_load_is_safe, logging from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
from ...utils.hub import cached_file from ...utils.hub import cached_file
from .configuration_qwen2_5_omni import ( from .configuration_qwen2_5_omni import (
Qwen2_5OmniAudioEncoderConfig, Qwen2_5OmniAudioEncoderConfig,
@@ -1424,6 +1424,7 @@ class Qwen2_5OmniAttention(nn.Module):
dropout=0.0 if not self.training else self.attention_dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
position_ids=position_ids, # pass positions for FA2
**kwargs, **kwargs,
) )
@@ -1607,9 +1608,25 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
# the hard coded `3` is for temporal, height and width. # the hard coded `3` is for temporal, height and width.
if position_ids is None: if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
elif position_ids.dim() == 2: elif position_ids.ndim == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
# where each dim indicates visual spatial positions for temporal/height/width grids.
# There are two scenarios when FA2-like packed masking might be activated.
# 1. User specifically passed packed `position_ids` and no attention mask.
# In this case we expect the useer to create correct position ids for all 3 grids
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
text_position_ids = position_ids[0]
position_ids = position_ids[1:]
else:
text_position_ids = position_ids[0]
# It may already have been prepared by e.g. `generate` # It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict): if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments # Prepare mask arguments
@@ -1619,7 +1636,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"cache_position": cache_position, "cache_position": cache_position,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"position_ids": position_ids, "position_ids": text_position_ids,
} }
# Create the masks # Create the masks
causal_mask_mapping = { causal_mask_mapping = {
@@ -1645,7 +1662,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type], attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids, position_ids=text_position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1804,6 +1821,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
use_audio_in_video: Optional[bool] = None, use_audio_in_video: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
video_second_per_grid: Optional[torch.LongTensor] = None, video_second_per_grid: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]: ) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
r""" r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
@@ -1959,6 +1977,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
hidden_states = outputs[0] hidden_states = outputs[0]
@@ -2146,9 +2165,25 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
# the hard coded `3` is for temporal, height and width. # the hard coded `3` is for temporal, height and width.
if position_ids is None: if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
elif position_ids.dim() == 2: elif position_ids.ndim == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
# where each dim indicates visual spatial positions for temporal/height/width grids.
# There are two scenarios when FA2-like packed masking might be activated.
# 1. User specifically passed packed `position_ids` and no attention mask.
# In this case we expect the useer to create correct position ids for all 3 grids
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
text_position_ids = position_ids[0]
position_ids = position_ids[1:]
else:
text_position_ids = position_ids[0]
# It may already have been prepared by e.g. `generate` # It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict): if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments # Prepare mask arguments
@@ -2158,7 +2193,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"cache_position": cache_position, "cache_position": cache_position,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"position_ids": position_ids, "position_ids": text_position_ids,
} }
# Create the masks # Create the masks
causal_mask_mapping = { causal_mask_mapping = {
@@ -2184,7 +2219,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type], attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids, position_ids=text_position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,

View File

@@ -49,7 +49,9 @@ from ...modeling_flash_attention_utils import is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_rope_utils import rope_config_validation from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...processing_utils import Unpack
from ...utils import ( from ...utils import (
TransformersKwargs,
auto_docstring, auto_docstring,
check_torch_load_is_safe, check_torch_load_is_safe,
logging, logging,
@@ -2259,6 +2261,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
use_audio_in_video: Optional[bool] = None, use_audio_in_video: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
video_second_per_grid: Optional[torch.LongTensor] = None, video_second_per_grid: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]: ) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
r""" r"""
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
@@ -2414,6 +2417,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position, cache_position=cache_position,
**kwargs,
) )
hidden_states = outputs[0] hidden_states = outputs[0]

View File

@@ -710,6 +710,7 @@ class Qwen2_5_VLAttention(nn.Module):
dropout=0.0 if not self.training else self.attention_dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
position_ids=position_ids, # pass positions for FA2
**kwargs, **kwargs,
) )
@@ -878,9 +879,25 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
# the hard coded `3` is for temporal, height and width. # the hard coded `3` is for temporal, height and width.
if position_ids is None: if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
elif position_ids.dim() == 2: elif position_ids.ndim == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
# where each dim indicates visual spatial positions for temporal/height/width grids.
# There are two scenarios when FA2-like packed masking might be activated.
# 1. User specifically passed packed `position_ids` and no attention mask.
# In this case we expect the useer to create correct position ids for all 3 grids
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
text_position_ids = position_ids[0]
position_ids = position_ids[1:]
else:
text_position_ids = position_ids[0]
# It may already have been prepared by e.g. `generate` # It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict): if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments # Prepare mask arguments
@@ -890,7 +907,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"cache_position": cache_position, "cache_position": cache_position,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"position_ids": position_ids, "position_ids": text_position_ids,
} }
# Create the masks # Create the masks
causal_mask_mapping = { causal_mask_mapping = {
@@ -916,7 +933,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type], attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids, position_ids=text_position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1279,16 +1296,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None: if position_ids is None:
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
# Only apply conversion for floating point tensors (inverted masks)
if attention_mask_tensor.dtype.is_floating_point:
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
# Calculate RoPE index once per generation in the pre-fill stage only. # Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length # When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled # It is safe to assume that `length!=1` means we're in pre-fill because compiled
@@ -1307,23 +1314,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
image_grid_thw, image_grid_thw,
video_grid_thw, video_grid_thw,
second_per_grid_ts=second_per_grid_ts, second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask_tensor, attention_mask=attention_mask,
) )
self.rope_deltas = rope_deltas self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else: else:
batch_size, seq_length, _ = inputs_embeds.shape batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0` if cache_position is not None:
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
position_ids = position_ids.add(delta) else:
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
position_ids += delta.to(position_ids.device)
outputs = self.language_model( outputs = self.language_model(
input_ids=None, input_ids=None,
@@ -1573,8 +1576,35 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
**kwargs, **kwargs,
) )
# Qwen2-5-VL position_ids are prepareed with rope_deltas in forward # Qwen2-5-VL position_ids are prepared with rope_deltas
model_inputs["position_ids"] = None if position_ids is None:
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
if cache_position[0] == 0 or self.model.rope_deltas is None:
vision_positions, rope_deltas = self.model.get_rope_index(
model_inputs.get("input_ids", None),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask,
)
self.model.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
elif "position_ids" in model_inputs:
position_ids = model_inputs["position_ids"][None, ...]
delta = self.model.rope_deltas
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
vision_positions = position_ids + delta.expand_as(position_ids)
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
# Concatenate "text + vision" positions into [4, bs, seq-len]
if "position_ids" not in model_inputs:
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
else:
text_positions = model_inputs["position_ids"][None, ...]
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
if cache_position[0] != 0: if cache_position[0] != 0:
model_inputs["pixel_values"] = None model_inputs["pixel_values"] = None

View File

@@ -630,16 +630,6 @@ class Qwen2_5_VLModel(Qwen2VLModel):
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None: if position_ids is None:
attention_mask_tensor = (
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
# Only apply conversion for floating point tensors (inverted masks)
if attention_mask_tensor.dtype.is_floating_point:
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
# Calculate RoPE index once per generation in the pre-fill stage only. # Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length # When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled # It is safe to assume that `length!=1` means we're in pre-fill because compiled
@@ -658,23 +648,19 @@ class Qwen2_5_VLModel(Qwen2VLModel):
image_grid_thw, image_grid_thw,
video_grid_thw, video_grid_thw,
second_per_grid_ts=second_per_grid_ts, second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask_tensor, attention_mask=attention_mask,
) )
self.rope_deltas = rope_deltas self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else: else:
batch_size, seq_length, _ = inputs_embeds.shape batch_size, seq_length, _ = inputs_embeds.shape
delta = (
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
if cache_position is not None
else 0
)
position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0` if cache_position is not None:
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
position_ids = position_ids.add(delta) else:
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
position_ids += delta.to(position_ids.device)
outputs = self.language_model( outputs = self.language_model(
input_ids=None, input_ids=None,
@@ -848,8 +834,35 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
**kwargs, **kwargs,
) )
# Qwen2-5-VL position_ids are prepareed with rope_deltas in forward # Qwen2-5-VL position_ids are prepared with rope_deltas
model_inputs["position_ids"] = None if position_ids is None:
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
if cache_position[0] == 0 or self.model.rope_deltas is None:
vision_positions, rope_deltas = self.model.get_rope_index(
model_inputs.get("input_ids", None),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=second_per_grid_ts,
attention_mask=attention_mask,
)
self.model.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
elif "position_ids" in model_inputs:
position_ids = model_inputs["position_ids"][None, ...]
delta = self.model.rope_deltas
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
vision_positions = position_ids + delta.expand_as(position_ids)
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
# Concatenate "text + vision" positions into [4, bs, seq-len]
if "position_ids" not in model_inputs:
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
else:
text_positions = model_inputs["position_ids"][None, ...]
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
if cache_position[0] != 0: if cache_position[0] != 0:
model_inputs["pixel_values"] = None model_inputs["pixel_values"] = None

View File

@@ -558,6 +558,7 @@ class Qwen2VLAttention(nn.Module):
dropout=0.0 if not self.training else self.attention_dropout, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, scaling=self.scaling,
sliding_window=self.sliding_window, sliding_window=self.sliding_window,
position_ids=position_ids, # pass positions for FA2
**kwargs, **kwargs,
) )
@@ -853,9 +854,25 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
# the hard coded `3` is for temporal, height and width. # the hard coded `3` is for temporal, height and width.
if position_ids is None: if position_ids is None:
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1) position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
elif position_ids.dim() == 2: elif position_ids.ndim == 2:
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
# where each dim indicates visual spatial positions for temporal/height/width grids.
# There are two scenarios when FA2-like packed masking might be activated.
# 1. User specifically passed packed `position_ids` and no attention mask.
# In this case we expect the useer to create correct position ids for all 3 grids
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
text_position_ids = position_ids[0]
position_ids = position_ids[1:]
else:
text_position_ids = position_ids[0]
# It may already have been prepared by e.g. `generate` # It may already have been prepared by e.g. `generate`
if not isinstance(causal_mask_mapping := attention_mask, dict): if not isinstance(causal_mask_mapping := attention_mask, dict):
# Prepare mask arguments # Prepare mask arguments
@@ -865,7 +882,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
"attention_mask": attention_mask, "attention_mask": attention_mask,
"cache_position": cache_position, "cache_position": cache_position,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"position_ids": position_ids, "position_ids": text_position_ids,
} }
# Create the masks # Create the masks
causal_mask_mapping = { causal_mask_mapping = {
@@ -891,7 +908,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type], attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids, position_ids=text_position_ids,
past_key_value=past_key_values, past_key_value=past_key_values,
output_attentions=output_attentions, output_attentions=output_attentions,
use_cache=use_cache, use_cache=use_cache,
@@ -1217,44 +1234,22 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
if position_ids is None: if position_ids is None:
attention_mask_tensor = ( if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
)
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
# Only apply conversion for floating point tensors (inverted masks)
if attention_mask_tensor.dtype.is_floating_point:
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or (past_key_values is None or past_key_values.get_seq_length() == 0)
)
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index( position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask_tensor input_ids, image_grid_thw, video_grid_thw, attention_mask
) )
self.rope_deltas = rope_deltas self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids # then use the prev pre-calculated rope-deltas to get the correct position ids
else: else:
batch_size, seq_length, _ = inputs_embeds.shape batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1) position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0` if cache_position is not None:
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
delta = delta.to(position_ids.device) else:
position_ids = position_ids.add(delta) delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids += delta.to(position_ids.device)
outputs = self.language_model( outputs = self.language_model(
input_ids=None, input_ids=None,
@@ -1465,7 +1460,41 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
) )
# Qwen2-VL position_ids are prepareed with rope_deltas in forward # Qwen2-VL position_ids are prepareed with rope_deltas in forward
model_inputs["position_ids"] = None if position_ids is None:
# Calculate RoPE index once per generation in the pre-fill stage only.
# When compiling, we can't check tensor values thus we check only input length
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
# models currently cannot do asssisted decoding
prefill_compiled_stage = is_torchdynamo_compiling() and (
(input_ids is not None and input_ids.shape[1] != 1)
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
)
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
(cache_position is not None and cache_position[0] == 0)
or (past_key_values is None or past_key_values.get_seq_length() == 0)
)
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.model.rope_deltas is None:
vision_positions, rope_deltas = self.model.get_rope_index(
model_inputs.get("input_ids", None),
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
attention_mask=attention_mask,
)
self.model.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
elif "position_ids" in model_inputs:
position_ids = model_inputs["position_ids"][None, ...]
delta = self.model.rope_deltas
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
vision_positions = position_ids + delta.expand_as(position_ids)
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
# Concatenate "text + vision" positions into [4, bs, seq-len]
if "position_ids" not in model_inputs:
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
else:
text_positions = model_inputs["position_ids"][None, ...]
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
if model_inputs["cache_position"][0] != 0: if model_inputs["cache_position"][0] != 0:
model_inputs["pixel_values"] = None model_inputs["pixel_values"] = None

View File

@@ -332,6 +332,92 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
raise ValueError("The eager model should not have SDPA attention layers") raise ValueError("The eager model should not have SDPA attention layers")
def flash_attention_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
)
.to(torch_device)
.eval()
)
# flatten
padfree_inputs_dict = {
"input_features": inputs_dict["input_features"],
"feature_attention_mask": inputs_dict["feature_attention_mask"],
"pixel_values": inputs_dict["pixel_values"],
"image_grid_thw": inputs_dict["image_grid_thw"],
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
}
# add position_ids
vision_position_ids, deltas = model.get_rope_index(
input_ids=inputs_dict["input_ids"],
image_grid_thw=inputs_dict["image_grid_thw"],
attention_mask=inputs_dict["attention_mask"],
audio_seqlens=torch.sum(inputs_dict["feature_attention_mask"], dim=1),
) # [3, bs, padded-seq-len]
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
3, -1
) # [3, bs*padfree-len]
text_padfree_positions = torch.cat(
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
) # [1, bs*padfree-len]
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
:, None, :
]
if fa_kwargs:
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
max_length = cu_seq_lens.diff().max().item()
padfree_inputs_dict.update(
{
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
"max_length_q": max_length,
"max_length_k": max_length,
}
)
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@unittest.skip("Cannot do contrastive generation, has custom `generate()`") @unittest.skip("Cannot do contrastive generation, has custom `generate()`")
def test_contrastive_generate(self): def test_contrastive_generate(self):
pass pass

View File

@@ -325,6 +325,89 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
) )
self.assertIsNotNone(outputs) self.assertIsNotNone(outputs)
def flash_attention_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
)
.to(torch_device)
.eval()
)
# flatten
padfree_inputs_dict = {
"pixel_values": inputs_dict["pixel_values"],
"image_grid_thw": inputs_dict["image_grid_thw"],
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
}
# add position_ids
vision_position_ids, deltas = model.model.get_rope_index(
input_ids=inputs_dict["input_ids"],
image_grid_thw=inputs_dict["image_grid_thw"],
attention_mask=inputs_dict["attention_mask"],
) # [3, bs, padded-seq-len]
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
3, -1
) # [3, bs*padfree-len]
text_padfree_positions = torch.cat(
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
) # [1, bs*padfree-len]
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
:, None, :
]
if fa_kwargs:
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
max_length = cu_seq_lens.diff().max().item()
padfree_inputs_dict.update(
{
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
"max_length_q": max_length,
"max_length_k": max_length,
}
)
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@unittest.skip(reason="Feedforward chunking is not yet supported") @unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass

View File

@@ -15,6 +15,7 @@
import copy import copy
import gc import gc
import tempfile
import unittest import unittest
import requests import requests
@@ -168,6 +169,7 @@ class Qwen2VLVisionText2TextModelTester:
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
input_ids[:, -1] = self.pad_token_id input_ids[:, -1] = self.pad_token_id
attention_mask[:, -1] = 0
input_ids[input_ids == self.video_token_id] = self.pad_token_id input_ids[input_ids == self.video_token_id] = self.pad_token_id
input_ids[input_ids == self.image_token_id] = self.pad_token_id input_ids[input_ids == self.image_token_id] = self.pad_token_id
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
@@ -281,6 +283,90 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4 generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
) )
def flash_attention_padding_matches_padding_free_with_position_ids(
self, attn_implementation: str, fa_kwargs: bool = False
):
max_new_tokens = 30
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.float16]:
dummy_input = dummy_input.to(torch.bfloat16)
# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
model = model_class(config)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
model = (
model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
)
.to(torch_device)
.eval()
)
# flatten
padfree_inputs_dict = {
"pixel_values": inputs_dict["pixel_values"],
"image_grid_thw": inputs_dict["image_grid_thw"],
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
}
# add position_ids
vision_position_ids, deltas = model.model.get_rope_index(
input_ids=inputs_dict["input_ids"],
image_grid_thw=inputs_dict["image_grid_thw"],
attention_mask=inputs_dict["attention_mask"],
) # [3, bs, padded-seq-len]
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
3, -1
) # [3, bs*padfree-len]
text_padfree_positions = torch.cat(
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
) # [1, bs*padfree-len]
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
:, None, :
]
if fa_kwargs:
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
max_length = cu_seq_lens.diff().max().item()
padfree_inputs_dict.update(
{
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
"max_length_q": max_length,
"max_length_k": max_length,
}
)
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
@unittest.skip(reason="Feedforward chunking is not yet supported") @unittest.skip(reason="Feedforward chunking is not yet supported")
def test_feed_forward_chunking(self): def test_feed_forward_chunking(self):
pass pass

View File

@@ -4129,13 +4129,14 @@ class ModelTesterMixin:
self.skipTest(reason="Model architecture does not support attentions") self.skipTest(reason="Model architecture does not support attentions")
max_new_tokens = 30 max_new_tokens = 30
support_flag = {
"sdpa": "_supports_sdpa",
"flash_attention_2": "_supports_flash_attn",
"flash_attention_3": "_supports_flash_attn",
}
for model_class in self.all_generative_model_classes: for model_class in self.all_generative_model_classes:
if not ( if not getattr(model_class, support_flag[attn_implementation]):
model_class._supports_flash_attn_2
if attn_implementation == "flash_attention_2"
else model_class._supports_flash_attn_3
):
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -4204,8 +4205,9 @@ class ModelTesterMixin:
.to(torch_device) .to(torch_device)
) )
res_padded = model(**inputs_dict) # We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
res_padfree = model(**padfree_inputs_dict) res_padded = model(**inputs_dict, use_cache=False)
res_padfree = model(**padfree_inputs_dict, use_cache=False)
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0] logits_padfree = res_padfree.logits[0]
@@ -4215,6 +4217,16 @@ class ModelTesterMixin:
tol = torch.finfo(torch.bfloat16).eps tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
# FIXME @raushan
@slow
def test_eager_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
@slow
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
@require_flash_attn @require_flash_attn
@require_torch_gpu @require_torch_gpu
@mark.flash_attn_test @mark.flash_attn_test