[qwen2 vl] fix packing with all attentions (#39447)
* fix qwen2 vl packing in FA2 * why? delete! * qwen2-5-vl seems to work now * update * fix tests * start by adapting FA2 tests * add similar tests for sdpa/eager * address comments * why is this even in conditional model and not base model?
This commit is contained in:
committed by
GitHub
parent
e42681b48b
commit
344012b3a6
@@ -40,7 +40,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelO
|
|||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...processing_utils import Unpack
|
from ...processing_utils import Unpack
|
||||||
from ...utils import auto_docstring, check_torch_load_is_safe, logging
|
from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
|
||||||
from ...utils.hub import cached_file
|
from ...utils.hub import cached_file
|
||||||
from .configuration_qwen2_5_omni import (
|
from .configuration_qwen2_5_omni import (
|
||||||
Qwen2_5OmniAudioEncoderConfig,
|
Qwen2_5OmniAudioEncoderConfig,
|
||||||
@@ -1424,6 +1424,7 @@ class Qwen2_5OmniAttention(nn.Module):
|
|||||||
dropout=0.0 if not self.training else self.attention_dropout,
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
scaling=self.scaling,
|
scaling=self.scaling,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
position_ids=position_ids, # pass positions for FA2
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1607,9 +1608,25 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
|||||||
# the hard coded `3` is for temporal, height and width.
|
# the hard coded `3` is for temporal, height and width.
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||||
elif position_ids.dim() == 2:
|
elif position_ids.ndim == 2:
|
||||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||||
|
|
||||||
|
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||||
|
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||||
|
# There are two scenarios when FA2-like packed masking might be activated.
|
||||||
|
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||||
|
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||||
|
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||||
|
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||||
|
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||||
|
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||||
|
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||||
|
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
position_ids = position_ids[1:]
|
||||||
|
else:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
|
||||||
# It may already have been prepared by e.g. `generate`
|
# It may already have been prepared by e.g. `generate`
|
||||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||||
# Prepare mask arguments
|
# Prepare mask arguments
|
||||||
@@ -1619,7 +1636,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"position_ids": position_ids,
|
"position_ids": text_position_ids,
|
||||||
}
|
}
|
||||||
# Create the masks
|
# Create the masks
|
||||||
causal_mask_mapping = {
|
causal_mask_mapping = {
|
||||||
@@ -1645,7 +1662,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||||
position_ids=position_ids,
|
position_ids=text_position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1804,6 +1821,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
use_audio_in_video: Optional[bool] = None,
|
use_audio_in_video: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
video_second_per_grid: Optional[torch.LongTensor] = None,
|
video_second_per_grid: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
|
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||||
@@ -1959,6 +1977,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
@@ -2146,9 +2165,25 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
|||||||
# the hard coded `3` is for temporal, height and width.
|
# the hard coded `3` is for temporal, height and width.
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||||
elif position_ids.dim() == 2:
|
elif position_ids.ndim == 2:
|
||||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||||
|
|
||||||
|
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||||
|
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||||
|
# There are two scenarios when FA2-like packed masking might be activated.
|
||||||
|
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||||
|
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||||
|
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||||
|
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||||
|
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||||
|
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||||
|
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||||
|
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
position_ids = position_ids[1:]
|
||||||
|
else:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
|
||||||
# It may already have been prepared by e.g. `generate`
|
# It may already have been prepared by e.g. `generate`
|
||||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||||
# Prepare mask arguments
|
# Prepare mask arguments
|
||||||
@@ -2158,7 +2193,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"position_ids": position_ids,
|
"position_ids": text_position_ids,
|
||||||
}
|
}
|
||||||
# Create the masks
|
# Create the masks
|
||||||
causal_mask_mapping = {
|
causal_mask_mapping = {
|
||||||
@@ -2184,7 +2219,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||||
position_ids=position_ids,
|
position_ids=text_position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
|||||||
@@ -49,7 +49,9 @@ from ...modeling_flash_attention_utils import is_flash_attn_available
|
|||||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||||
from ...modeling_rope_utils import rope_config_validation
|
from ...modeling_rope_utils import rope_config_validation
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
from ...processing_utils import Unpack
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
|
TransformersKwargs,
|
||||||
auto_docstring,
|
auto_docstring,
|
||||||
check_torch_load_is_safe,
|
check_torch_load_is_safe,
|
||||||
logging,
|
logging,
|
||||||
@@ -2259,6 +2261,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
use_audio_in_video: Optional[bool] = None,
|
use_audio_in_video: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
video_second_per_grid: Optional[torch.LongTensor] = None,
|
video_second_per_grid: Optional[torch.LongTensor] = None,
|
||||||
|
**kwargs: Unpack[TransformersKwargs],
|
||||||
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
|
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
|
||||||
r"""
|
r"""
|
||||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||||
@@ -2414,6 +2417,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
|||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
cache_position=cache_position,
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
|
|||||||
@@ -710,6 +710,7 @@ class Qwen2_5_VLAttention(nn.Module):
|
|||||||
dropout=0.0 if not self.training else self.attention_dropout,
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
scaling=self.scaling,
|
scaling=self.scaling,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
position_ids=position_ids, # pass positions for FA2
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -878,9 +879,25 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
# the hard coded `3` is for temporal, height and width.
|
# the hard coded `3` is for temporal, height and width.
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||||
elif position_ids.dim() == 2:
|
elif position_ids.ndim == 2:
|
||||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||||
|
|
||||||
|
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||||
|
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||||
|
# There are two scenarios when FA2-like packed masking might be activated.
|
||||||
|
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||||
|
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||||
|
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||||
|
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||||
|
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||||
|
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||||
|
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||||
|
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
position_ids = position_ids[1:]
|
||||||
|
else:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
|
||||||
# It may already have been prepared by e.g. `generate`
|
# It may already have been prepared by e.g. `generate`
|
||||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||||
# Prepare mask arguments
|
# Prepare mask arguments
|
||||||
@@ -890,7 +907,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"position_ids": position_ids,
|
"position_ids": text_position_ids,
|
||||||
}
|
}
|
||||||
# Create the masks
|
# Create the masks
|
||||||
causal_mask_mapping = {
|
causal_mask_mapping = {
|
||||||
@@ -916,7 +933,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||||
position_ids=position_ids,
|
position_ids=text_position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1279,16 +1296,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
attention_mask_tensor = (
|
|
||||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
|
||||||
)
|
|
||||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
|
||||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
|
||||||
# Only apply conversion for floating point tensors (inverted masks)
|
|
||||||
if attention_mask_tensor.dtype.is_floating_point:
|
|
||||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
|
||||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
|
||||||
|
|
||||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||||
# When compiling, we can't check tensor values thus we check only input length
|
# When compiling, we can't check tensor values thus we check only input length
|
||||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||||
@@ -1307,23 +1314,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
|||||||
image_grid_thw,
|
image_grid_thw,
|
||||||
video_grid_thw,
|
video_grid_thw,
|
||||||
second_per_grid_ts=second_per_grid_ts,
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
attention_mask=attention_mask_tensor,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
self.rope_deltas = rope_deltas
|
self.rope_deltas = rope_deltas
|
||||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
|
||||||
else:
|
else:
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
delta = (
|
|
||||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
|
||||||
if cache_position is not None
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
||||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
if cache_position is not None:
|
||||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||||
position_ids = position_ids.add(delta)
|
else:
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
|
||||||
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
|
||||||
|
position_ids += delta.to(position_ids.device)
|
||||||
|
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -1573,8 +1576,35 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Qwen2-5-VL position_ids are prepareed with rope_deltas in forward
|
# Qwen2-5-VL position_ids are prepared with rope_deltas
|
||||||
model_inputs["position_ids"] = None
|
if position_ids is None:
|
||||||
|
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||||
|
# When compiling, we can't check tensor values thus we check only input length
|
||||||
|
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||||
|
# models currently cannot do asssisted decoding
|
||||||
|
if cache_position[0] == 0 or self.model.rope_deltas is None:
|
||||||
|
vision_positions, rope_deltas = self.model.get_rope_index(
|
||||||
|
model_inputs.get("input_ids", None),
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
self.model.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
elif "position_ids" in model_inputs:
|
||||||
|
position_ids = model_inputs["position_ids"][None, ...]
|
||||||
|
delta = self.model.rope_deltas
|
||||||
|
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
|
||||||
|
vision_positions = position_ids + delta.expand_as(position_ids)
|
||||||
|
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
|
||||||
|
|
||||||
|
# Concatenate "text + vision" positions into [4, bs, seq-len]
|
||||||
|
if "position_ids" not in model_inputs:
|
||||||
|
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
|
||||||
|
else:
|
||||||
|
text_positions = model_inputs["position_ids"][None, ...]
|
||||||
|
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
|
||||||
|
|
||||||
if cache_position[0] != 0:
|
if cache_position[0] != 0:
|
||||||
model_inputs["pixel_values"] = None
|
model_inputs["pixel_values"] = None
|
||||||
|
|||||||
@@ -630,16 +630,6 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
|||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
attention_mask_tensor = (
|
|
||||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
|
||||||
)
|
|
||||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
|
||||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
|
||||||
# Only apply conversion for floating point tensors (inverted masks)
|
|
||||||
if attention_mask_tensor.dtype.is_floating_point:
|
|
||||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
|
||||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
|
||||||
|
|
||||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||||
# When compiling, we can't check tensor values thus we check only input length
|
# When compiling, we can't check tensor values thus we check only input length
|
||||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||||
@@ -658,23 +648,19 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
|||||||
image_grid_thw,
|
image_grid_thw,
|
||||||
video_grid_thw,
|
video_grid_thw,
|
||||||
second_per_grid_ts=second_per_grid_ts,
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
attention_mask=attention_mask_tensor,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
self.rope_deltas = rope_deltas
|
self.rope_deltas = rope_deltas
|
||||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
|
||||||
else:
|
else:
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
delta = (
|
|
||||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
|
||||||
if cache_position is not None
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
||||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
if cache_position is not None:
|
||||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||||
position_ids = position_ids.add(delta)
|
else:
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
|
||||||
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
|
||||||
|
position_ids += delta.to(position_ids.device)
|
||||||
|
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -848,8 +834,35 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Qwen2-5-VL position_ids are prepareed with rope_deltas in forward
|
# Qwen2-5-VL position_ids are prepared with rope_deltas
|
||||||
model_inputs["position_ids"] = None
|
if position_ids is None:
|
||||||
|
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||||
|
# When compiling, we can't check tensor values thus we check only input length
|
||||||
|
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||||
|
# models currently cannot do asssisted decoding
|
||||||
|
if cache_position[0] == 0 or self.model.rope_deltas is None:
|
||||||
|
vision_positions, rope_deltas = self.model.get_rope_index(
|
||||||
|
model_inputs.get("input_ids", None),
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
self.model.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
elif "position_ids" in model_inputs:
|
||||||
|
position_ids = model_inputs["position_ids"][None, ...]
|
||||||
|
delta = self.model.rope_deltas
|
||||||
|
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
|
||||||
|
vision_positions = position_ids + delta.expand_as(position_ids)
|
||||||
|
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
|
||||||
|
|
||||||
|
# Concatenate "text + vision" positions into [4, bs, seq-len]
|
||||||
|
if "position_ids" not in model_inputs:
|
||||||
|
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
|
||||||
|
else:
|
||||||
|
text_positions = model_inputs["position_ids"][None, ...]
|
||||||
|
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
|
||||||
|
|
||||||
if cache_position[0] != 0:
|
if cache_position[0] != 0:
|
||||||
model_inputs["pixel_values"] = None
|
model_inputs["pixel_values"] = None
|
||||||
|
|||||||
@@ -558,6 +558,7 @@ class Qwen2VLAttention(nn.Module):
|
|||||||
dropout=0.0 if not self.training else self.attention_dropout,
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
scaling=self.scaling,
|
scaling=self.scaling,
|
||||||
sliding_window=self.sliding_window,
|
sliding_window=self.sliding_window,
|
||||||
|
position_ids=position_ids, # pass positions for FA2
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -853,9 +854,25 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
|||||||
# the hard coded `3` is for temporal, height and width.
|
# the hard coded `3` is for temporal, height and width.
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||||
elif position_ids.dim() == 2:
|
elif position_ids.ndim == 2:
|
||||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||||
|
|
||||||
|
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||||
|
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||||
|
# There are two scenarios when FA2-like packed masking might be activated.
|
||||||
|
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||||
|
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||||
|
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||||
|
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||||
|
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||||
|
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||||
|
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||||
|
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
position_ids = position_ids[1:]
|
||||||
|
else:
|
||||||
|
text_position_ids = position_ids[0]
|
||||||
|
|
||||||
# It may already have been prepared by e.g. `generate`
|
# It may already have been prepared by e.g. `generate`
|
||||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||||
# Prepare mask arguments
|
# Prepare mask arguments
|
||||||
@@ -865,7 +882,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
|||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
"cache_position": cache_position,
|
"cache_position": cache_position,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"position_ids": position_ids,
|
"position_ids": text_position_ids,
|
||||||
}
|
}
|
||||||
# Create the masks
|
# Create the masks
|
||||||
causal_mask_mapping = {
|
causal_mask_mapping = {
|
||||||
@@ -891,7 +908,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
|||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||||
position_ids=position_ids,
|
position_ids=text_position_ids,
|
||||||
past_key_value=past_key_values,
|
past_key_value=past_key_values,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
@@ -1217,44 +1234,22 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
|||||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
attention_mask_tensor = (
|
if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
|
||||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
|
||||||
)
|
|
||||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
|
||||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
|
||||||
# Only apply conversion for floating point tensors (inverted masks)
|
|
||||||
if attention_mask_tensor.dtype.is_floating_point:
|
|
||||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
|
||||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
|
||||||
|
|
||||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
|
||||||
# When compiling, we can't check tensor values thus we check only input length
|
|
||||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
|
||||||
# models currently cannot do asssisted decoding
|
|
||||||
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
|
||||||
(input_ids is not None and input_ids.shape[1] != 1)
|
|
||||||
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
|
||||||
)
|
|
||||||
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
|
||||||
(cache_position is not None and cache_position[0] == 0)
|
|
||||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
|
||||||
)
|
|
||||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
|
||||||
position_ids, rope_deltas = self.get_rope_index(
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
input_ids, image_grid_thw, video_grid_thw, attention_mask_tensor
|
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||||
)
|
)
|
||||||
self.rope_deltas = rope_deltas
|
self.rope_deltas = rope_deltas
|
||||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
else:
|
else:
|
||||||
batch_size, seq_length, _ = inputs_embeds.shape
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
|
||||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
||||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
if cache_position is not None:
|
||||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||||
delta = delta.to(position_ids.device)
|
else:
|
||||||
position_ids = position_ids.add(delta)
|
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||||
|
position_ids += delta.to(position_ids.device)
|
||||||
|
|
||||||
outputs = self.language_model(
|
outputs = self.language_model(
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -1465,7 +1460,41 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Qwen2-VL position_ids are prepareed with rope_deltas in forward
|
# Qwen2-VL position_ids are prepareed with rope_deltas in forward
|
||||||
model_inputs["position_ids"] = None
|
if position_ids is None:
|
||||||
|
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||||
|
# When compiling, we can't check tensor values thus we check only input length
|
||||||
|
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||||
|
# models currently cannot do asssisted decoding
|
||||||
|
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
||||||
|
(input_ids is not None and input_ids.shape[1] != 1)
|
||||||
|
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
||||||
|
)
|
||||||
|
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
||||||
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
|
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||||
|
)
|
||||||
|
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.model.rope_deltas is None:
|
||||||
|
vision_positions, rope_deltas = self.model.get_rope_index(
|
||||||
|
model_inputs.get("input_ids", None),
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
)
|
||||||
|
self.model.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
elif "position_ids" in model_inputs:
|
||||||
|
position_ids = model_inputs["position_ids"][None, ...]
|
||||||
|
delta = self.model.rope_deltas
|
||||||
|
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
|
||||||
|
vision_positions = position_ids + delta.expand_as(position_ids)
|
||||||
|
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
|
||||||
|
|
||||||
|
# Concatenate "text + vision" positions into [4, bs, seq-len]
|
||||||
|
if "position_ids" not in model_inputs:
|
||||||
|
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
|
||||||
|
else:
|
||||||
|
text_positions = model_inputs["position_ids"][None, ...]
|
||||||
|
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
|
||||||
|
|
||||||
if model_inputs["cache_position"][0] != 0:
|
if model_inputs["cache_position"][0] != 0:
|
||||||
model_inputs["pixel_values"] = None
|
model_inputs["pixel_values"] = None
|
||||||
|
|||||||
@@ -332,6 +332,92 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
|||||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||||
raise ValueError("The eager model should not have SDPA attention layers")
|
raise ValueError("The eager model should not have SDPA attention layers")
|
||||||
|
|
||||||
|
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||||
|
self, attn_implementation: str, fa_kwargs: bool = False
|
||||||
|
):
|
||||||
|
max_new_tokens = 30
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
|
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||||
|
dummy_input = dummy_input.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# make sure that all models have enough positions for generation
|
||||||
|
if hasattr(config, "max_position_embeddings"):
|
||||||
|
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||||
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
|
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||||
|
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||||
|
|
||||||
|
model = (
|
||||||
|
model_class.from_pretrained(
|
||||||
|
tmpdirname,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
padfree_inputs_dict = {
|
||||||
|
"input_features": inputs_dict["input_features"],
|
||||||
|
"feature_attention_mask": inputs_dict["feature_attention_mask"],
|
||||||
|
"pixel_values": inputs_dict["pixel_values"],
|
||||||
|
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||||
|
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||||
|
}
|
||||||
|
|
||||||
|
# add position_ids
|
||||||
|
vision_position_ids, deltas = model.get_rope_index(
|
||||||
|
input_ids=inputs_dict["input_ids"],
|
||||||
|
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||||
|
attention_mask=inputs_dict["attention_mask"],
|
||||||
|
audio_seqlens=torch.sum(inputs_dict["feature_attention_mask"], dim=1),
|
||||||
|
) # [3, bs, padded-seq-len]
|
||||||
|
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||||
|
3, -1
|
||||||
|
) # [3, bs*padfree-len]
|
||||||
|
text_padfree_positions = torch.cat(
|
||||||
|
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||||
|
) # [1, bs*padfree-len]
|
||||||
|
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||||
|
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||||
|
:, None, :
|
||||||
|
]
|
||||||
|
|
||||||
|
if fa_kwargs:
|
||||||
|
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||||
|
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||||
|
max_length = cu_seq_lens.diff().max().item()
|
||||||
|
padfree_inputs_dict.update(
|
||||||
|
{
|
||||||
|
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||||
|
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||||
|
"max_length_q": max_length,
|
||||||
|
"max_length_k": max_length,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
res_padded = model(**inputs_dict, use_cache=False)
|
||||||
|
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||||
|
|
||||||
|
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||||
|
logits_padfree = res_padfree.logits[0]
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||||
|
# acceptable numerical instability
|
||||||
|
tol = torch.finfo(torch.bfloat16).eps
|
||||||
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||||
def test_contrastive_generate(self):
|
def test_contrastive_generate(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -325,6 +325,89 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
)
|
)
|
||||||
self.assertIsNotNone(outputs)
|
self.assertIsNotNone(outputs)
|
||||||
|
|
||||||
|
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||||
|
self, attn_implementation: str, fa_kwargs: bool = False
|
||||||
|
):
|
||||||
|
max_new_tokens = 30
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
|
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||||
|
dummy_input = dummy_input.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# make sure that all models have enough positions for generation
|
||||||
|
if hasattr(config, "max_position_embeddings"):
|
||||||
|
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||||
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
|
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||||
|
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||||
|
|
||||||
|
model = (
|
||||||
|
model_class.from_pretrained(
|
||||||
|
tmpdirname,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
padfree_inputs_dict = {
|
||||||
|
"pixel_values": inputs_dict["pixel_values"],
|
||||||
|
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||||
|
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||||
|
}
|
||||||
|
|
||||||
|
# add position_ids
|
||||||
|
vision_position_ids, deltas = model.model.get_rope_index(
|
||||||
|
input_ids=inputs_dict["input_ids"],
|
||||||
|
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||||
|
attention_mask=inputs_dict["attention_mask"],
|
||||||
|
) # [3, bs, padded-seq-len]
|
||||||
|
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||||
|
3, -1
|
||||||
|
) # [3, bs*padfree-len]
|
||||||
|
text_padfree_positions = torch.cat(
|
||||||
|
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||||
|
) # [1, bs*padfree-len]
|
||||||
|
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||||
|
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||||
|
:, None, :
|
||||||
|
]
|
||||||
|
|
||||||
|
if fa_kwargs:
|
||||||
|
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||||
|
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||||
|
max_length = cu_seq_lens.diff().max().item()
|
||||||
|
padfree_inputs_dict.update(
|
||||||
|
{
|
||||||
|
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||||
|
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||||
|
"max_length_q": max_length,
|
||||||
|
"max_length_k": max_length,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
res_padded = model(**inputs_dict, use_cache=False)
|
||||||
|
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||||
|
|
||||||
|
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||||
|
logits_padfree = res_padfree.logits[0]
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||||
|
# acceptable numerical instability
|
||||||
|
tol = torch.finfo(torch.bfloat16).eps
|
||||||
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import gc
|
import gc
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -168,6 +169,7 @@ class Qwen2VLVisionText2TextModelTester:
|
|||||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||||
|
|
||||||
input_ids[:, -1] = self.pad_token_id
|
input_ids[:, -1] = self.pad_token_id
|
||||||
|
attention_mask[:, -1] = 0
|
||||||
input_ids[input_ids == self.video_token_id] = self.pad_token_id
|
input_ids[input_ids == self.video_token_id] = self.pad_token_id
|
||||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||||
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
|
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
|
||||||
@@ -281,6 +283,90 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
|
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||||
|
self, attn_implementation: str, fa_kwargs: bool = False
|
||||||
|
):
|
||||||
|
max_new_tokens = 30
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
dummy_input = inputs_dict[model_class.main_input_name]
|
||||||
|
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||||
|
dummy_input = dummy_input.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# make sure that all models have enough positions for generation
|
||||||
|
if hasattr(config, "max_position_embeddings"):
|
||||||
|
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
model.save_pretrained(tmpdirname)
|
||||||
|
|
||||||
|
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||||
|
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||||
|
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||||
|
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||||
|
|
||||||
|
model = (
|
||||||
|
model_class.from_pretrained(
|
||||||
|
tmpdirname,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
attn_implementation=attn_implementation,
|
||||||
|
)
|
||||||
|
.to(torch_device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
|
||||||
|
# flatten
|
||||||
|
padfree_inputs_dict = {
|
||||||
|
"pixel_values": inputs_dict["pixel_values"],
|
||||||
|
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||||
|
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||||
|
}
|
||||||
|
|
||||||
|
# add position_ids
|
||||||
|
vision_position_ids, deltas = model.model.get_rope_index(
|
||||||
|
input_ids=inputs_dict["input_ids"],
|
||||||
|
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||||
|
attention_mask=inputs_dict["attention_mask"],
|
||||||
|
) # [3, bs, padded-seq-len]
|
||||||
|
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||||
|
3, -1
|
||||||
|
) # [3, bs*padfree-len]
|
||||||
|
text_padfree_positions = torch.cat(
|
||||||
|
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||||
|
) # [1, bs*padfree-len]
|
||||||
|
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||||
|
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||||
|
:, None, :
|
||||||
|
]
|
||||||
|
|
||||||
|
if fa_kwargs:
|
||||||
|
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||||
|
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||||
|
max_length = cu_seq_lens.diff().max().item()
|
||||||
|
padfree_inputs_dict.update(
|
||||||
|
{
|
||||||
|
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||||
|
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||||
|
"max_length_q": max_length,
|
||||||
|
"max_length_k": max_length,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||||
|
res_padded = model(**inputs_dict, use_cache=False)
|
||||||
|
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||||
|
|
||||||
|
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||||
|
logits_padfree = res_padfree.logits[0]
|
||||||
|
|
||||||
|
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||||
|
# acceptable numerical instability
|
||||||
|
tol = torch.finfo(torch.bfloat16).eps
|
||||||
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||||
def test_feed_forward_chunking(self):
|
def test_feed_forward_chunking(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -4129,13 +4129,14 @@ class ModelTesterMixin:
|
|||||||
self.skipTest(reason="Model architecture does not support attentions")
|
self.skipTest(reason="Model architecture does not support attentions")
|
||||||
|
|
||||||
max_new_tokens = 30
|
max_new_tokens = 30
|
||||||
|
support_flag = {
|
||||||
|
"sdpa": "_supports_sdpa",
|
||||||
|
"flash_attention_2": "_supports_flash_attn",
|
||||||
|
"flash_attention_3": "_supports_flash_attn",
|
||||||
|
}
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
if not (
|
if not getattr(model_class, support_flag[attn_implementation]):
|
||||||
model_class._supports_flash_attn_2
|
|
||||||
if attn_implementation == "flash_attention_2"
|
|
||||||
else model_class._supports_flash_attn_3
|
|
||||||
):
|
|
||||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
@@ -4204,8 +4205,9 @@ class ModelTesterMixin:
|
|||||||
.to(torch_device)
|
.to(torch_device)
|
||||||
)
|
)
|
||||||
|
|
||||||
res_padded = model(**inputs_dict)
|
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||||
res_padfree = model(**padfree_inputs_dict)
|
res_padded = model(**inputs_dict, use_cache=False)
|
||||||
|
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||||
|
|
||||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||||
logits_padfree = res_padfree.logits[0]
|
logits_padfree = res_padfree.logits[0]
|
||||||
@@ -4215,6 +4217,16 @@ class ModelTesterMixin:
|
|||||||
tol = torch.finfo(torch.bfloat16).eps
|
tol = torch.finfo(torch.bfloat16).eps
|
||||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
|
||||||
|
# FIXME @raushan
|
||||||
|
@slow
|
||||||
|
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||||
|
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||||
|
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@mark.flash_attn_test
|
@mark.flash_attn_test
|
||||||
|
|||||||
Reference in New Issue
Block a user