[qwen2 vl] fix packing with all attentions (#39447)
* fix qwen2 vl packing in FA2 * why? delete! * qwen2-5-vl seems to work now * update * fix tests * start by adapting FA2 tests * add similar tests for sdpa/eager * address comments * why is this even in conditional model and not base model?
This commit is contained in:
committed by
GitHub
parent
e42681b48b
commit
344012b3a6
@@ -40,7 +40,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelO
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import auto_docstring, check_torch_load_is_safe, logging
|
||||
from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging
|
||||
from ...utils.hub import cached_file
|
||||
from .configuration_qwen2_5_omni import (
|
||||
Qwen2_5OmniAudioEncoderConfig,
|
||||
@@ -1424,6 +1424,7 @@ class Qwen2_5OmniAttention(nn.Module):
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
position_ids=position_ids, # pass positions for FA2
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1607,9 +1608,25 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
# the hard coded `3` is for temporal, height and width.
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||
elif position_ids.dim() == 2:
|
||||
elif position_ids.ndim == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||
# There are two scenarios when FA2-like packed masking might be activated.
|
||||
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||
text_position_ids = position_ids[0]
|
||||
position_ids = position_ids[1:]
|
||||
else:
|
||||
text_position_ids = position_ids[0]
|
||||
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
@@ -1619,7 +1636,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"position_ids": position_ids,
|
||||
"position_ids": text_position_ids,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
@@ -1645,7 +1662,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
position_ids=text_position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@@ -1804,6 +1821,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
use_audio_in_video: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
video_second_per_grid: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
|
||||
r"""
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
@@ -1959,6 +1977,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
@@ -2146,9 +2165,25 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
# the hard coded `3` is for temporal, height and width.
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||
elif position_ids.dim() == 2:
|
||||
elif position_ids.ndim == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||
# There are two scenarios when FA2-like packed masking might be activated.
|
||||
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||
text_position_ids = position_ids[0]
|
||||
position_ids = position_ids[1:]
|
||||
else:
|
||||
text_position_ids = position_ids[0]
|
||||
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
@@ -2158,7 +2193,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"position_ids": position_ids,
|
||||
"position_ids": text_position_ids,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
@@ -2184,7 +2219,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
position_ids=text_position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
|
||||
@@ -49,7 +49,9 @@ from ...modeling_flash_attention_utils import is_flash_attn_available
|
||||
from ...modeling_outputs import BaseModelOutput, ModelOutput
|
||||
from ...modeling_rope_utils import rope_config_validation
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import (
|
||||
TransformersKwargs,
|
||||
auto_docstring,
|
||||
check_torch_load_is_safe,
|
||||
logging,
|
||||
@@ -2259,6 +2261,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
use_audio_in_video: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
video_second_per_grid: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> Union[tuple, Qwen2_5OmniThinkerCausalLMOutputWithPast]:
|
||||
r"""
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
@@ -2414,6 +2417,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
@@ -710,6 +710,7 @@ class Qwen2_5_VLAttention(nn.Module):
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
position_ids=position_ids, # pass positions for FA2
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -878,9 +879,25 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
# the hard coded `3` is for temporal, height and width.
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||
elif position_ids.dim() == 2:
|
||||
elif position_ids.ndim == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||
# There are two scenarios when FA2-like packed masking might be activated.
|
||||
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||
text_position_ids = position_ids[0]
|
||||
position_ids = position_ids[1:]
|
||||
else:
|
||||
text_position_ids = position_ids[0]
|
||||
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
@@ -890,7 +907,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"position_ids": position_ids,
|
||||
"position_ids": text_position_ids,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
@@ -916,7 +933,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
position_ids=text_position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@@ -1279,16 +1296,6 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
# Only apply conversion for floating point tensors (inverted masks)
|
||||
if attention_mask_tensor.dtype.is_floating_point:
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
@@ -1307,23 +1314,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask_tensor,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
||||
if cache_position is not None:
|
||||
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
else:
|
||||
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
|
||||
position_ids += delta.to(position_ids.device)
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=None,
|
||||
@@ -1573,8 +1576,35 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Qwen2-5-VL position_ids are prepareed with rope_deltas in forward
|
||||
model_inputs["position_ids"] = None
|
||||
# Qwen2-5-VL position_ids are prepared with rope_deltas
|
||||
if position_ids is None:
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do asssisted decoding
|
||||
if cache_position[0] == 0 or self.model.rope_deltas is None:
|
||||
vision_positions, rope_deltas = self.model.get_rope_index(
|
||||
model_inputs.get("input_ids", None),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.model.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
elif "position_ids" in model_inputs:
|
||||
position_ids = model_inputs["position_ids"][None, ...]
|
||||
delta = self.model.rope_deltas
|
||||
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
|
||||
vision_positions = position_ids + delta.expand_as(position_ids)
|
||||
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
|
||||
|
||||
# Concatenate "text + vision" positions into [4, bs, seq-len]
|
||||
if "position_ids" not in model_inputs:
|
||||
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
|
||||
else:
|
||||
text_positions = model_inputs["position_ids"][None, ...]
|
||||
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
|
||||
|
||||
if cache_position[0] != 0:
|
||||
model_inputs["pixel_values"] = None
|
||||
|
||||
@@ -630,16 +630,6 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
# Only apply conversion for floating point tensors (inverted masks)
|
||||
if attention_mask_tensor.dtype.is_floating_point:
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
@@ -658,23 +648,19 @@ class Qwen2_5_VLModel(Qwen2VLModel):
|
||||
image_grid_thw,
|
||||
video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask_tensor,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = (
|
||||
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
if cache_position is not None
|
||||
else 0
|
||||
)
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
||||
if cache_position is not None:
|
||||
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
else:
|
||||
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
|
||||
position_ids += delta.to(position_ids.device)
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=None,
|
||||
@@ -848,8 +834,35 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Qwen2-5-VL position_ids are prepareed with rope_deltas in forward
|
||||
model_inputs["position_ids"] = None
|
||||
# Qwen2-5-VL position_ids are prepared with rope_deltas
|
||||
if position_ids is None:
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do asssisted decoding
|
||||
if cache_position[0] == 0 or self.model.rope_deltas is None:
|
||||
vision_positions, rope_deltas = self.model.get_rope_index(
|
||||
model_inputs.get("input_ids", None),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.model.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
elif "position_ids" in model_inputs:
|
||||
position_ids = model_inputs["position_ids"][None, ...]
|
||||
delta = self.model.rope_deltas
|
||||
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
|
||||
vision_positions = position_ids + delta.expand_as(position_ids)
|
||||
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
|
||||
|
||||
# Concatenate "text + vision" positions into [4, bs, seq-len]
|
||||
if "position_ids" not in model_inputs:
|
||||
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
|
||||
else:
|
||||
text_positions = model_inputs["position_ids"][None, ...]
|
||||
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
|
||||
|
||||
if cache_position[0] != 0:
|
||||
model_inputs["pixel_values"] = None
|
||||
|
||||
@@ -558,6 +558,7 @@ class Qwen2VLAttention(nn.Module):
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
sliding_window=self.sliding_window,
|
||||
position_ids=position_ids, # pass positions for FA2
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -853,9 +854,25 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
# the hard coded `3` is for temporal, height and width.
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
|
||||
elif position_ids.dim() == 2:
|
||||
elif position_ids.ndim == 2:
|
||||
position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
|
||||
|
||||
# NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
|
||||
# where each dim indicates visual spatial positions for temporal/height/width grids.
|
||||
# There are two scenarios when FA2-like packed masking might be activated.
|
||||
# 1. User specifically passed packed `position_ids` and no attention mask.
|
||||
# In this case we expect the useer to create correct position ids for all 3 grids
|
||||
# and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
|
||||
# 2. User runs forward with no attention mask and no position ids. In this case, position ids
|
||||
# are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
|
||||
# prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
|
||||
# text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
|
||||
if position_ids.ndim == 3 and position_ids.shape[0] == 4:
|
||||
text_position_ids = position_ids[0]
|
||||
position_ids = position_ids[1:]
|
||||
else:
|
||||
text_position_ids = position_ids[0]
|
||||
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
# Prepare mask arguments
|
||||
@@ -865,7 +882,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
"attention_mask": attention_mask,
|
||||
"cache_position": cache_position,
|
||||
"past_key_values": past_key_values,
|
||||
"position_ids": position_ids,
|
||||
"position_ids": text_position_ids,
|
||||
}
|
||||
# Create the masks
|
||||
causal_mask_mapping = {
|
||||
@@ -891,7 +908,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel):
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
|
||||
position_ids=position_ids,
|
||||
position_ids=text_position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
@@ -1217,44 +1234,22 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
||||
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
||||
|
||||
if position_ids is None:
|
||||
attention_mask_tensor = (
|
||||
attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
|
||||
)
|
||||
if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
|
||||
attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
|
||||
# Only apply conversion for floating point tensors (inverted masks)
|
||||
if attention_mask_tensor.dtype.is_floating_point:
|
||||
attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
|
||||
attention_mask_tensor = (1.0 - attention_mask_tensor).int()
|
||||
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do asssisted decoding
|
||||
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
||||
(input_ids is not None and input_ids.shape[1] != 1)
|
||||
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
||||
)
|
||||
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
)
|
||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
|
||||
if self.rope_deltas is None or cache_position is None or cache_position[0] == 0:
|
||||
position_ids, rope_deltas = self.get_rope_index(
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask_tensor
|
||||
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||
)
|
||||
self.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
else:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
||||
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
||||
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
||||
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||
position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
|
||||
if cache_position is not None:
|
||||
delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||
else:
|
||||
delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
|
||||
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
||||
delta = delta.to(position_ids.device)
|
||||
position_ids = position_ids.add(delta)
|
||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
||||
position_ids += delta.to(position_ids.device)
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=None,
|
||||
@@ -1465,7 +1460,41 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
||||
)
|
||||
|
||||
# Qwen2-VL position_ids are prepareed with rope_deltas in forward
|
||||
model_inputs["position_ids"] = None
|
||||
if position_ids is None:
|
||||
# Calculate RoPE index once per generation in the pre-fill stage only.
|
||||
# When compiling, we can't check tensor values thus we check only input length
|
||||
# It is safe to assume that `length!=1` means we're in pre-fill because compiled
|
||||
# models currently cannot do asssisted decoding
|
||||
prefill_compiled_stage = is_torchdynamo_compiling() and (
|
||||
(input_ids is not None and input_ids.shape[1] != 1)
|
||||
or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
|
||||
)
|
||||
prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
|
||||
(cache_position is not None and cache_position[0] == 0)
|
||||
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
||||
)
|
||||
if (prefill_compiled_stage or prefill_noncompiled_stage) or self.model.rope_deltas is None:
|
||||
vision_positions, rope_deltas = self.model.get_rope_index(
|
||||
model_inputs.get("input_ids", None),
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
self.model.rope_deltas = rope_deltas
|
||||
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||
elif "position_ids" in model_inputs:
|
||||
position_ids = model_inputs["position_ids"][None, ...]
|
||||
delta = self.model.rope_deltas
|
||||
delta = delta.repeat_interleave(position_ids.shape[1] // delta.shape[0], dim=0)
|
||||
vision_positions = position_ids + delta.expand_as(position_ids)
|
||||
vision_positions = vision_positions.expand(3, vision_positions.shape[1], -1)
|
||||
|
||||
# Concatenate "text + vision" positions into [4, bs, seq-len]
|
||||
if "position_ids" not in model_inputs:
|
||||
text_positions = torch.arange(input_ids, device=input_ids.device)[None, None, :]
|
||||
else:
|
||||
text_positions = model_inputs["position_ids"][None, ...]
|
||||
model_inputs["position_ids"] = torch.cat([text_positions, vision_positions], dim=0)
|
||||
|
||||
if model_inputs["cache_position"][0] != 0:
|
||||
model_inputs["pixel_values"] = None
|
||||
|
||||
@@ -332,6 +332,92 @@ class Qwen2_5OmniThinkerForConditionalGenerationModelTest(ModelTesterMixin, Gene
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
"input_features": inputs_dict["input_features"],
|
||||
"feature_attention_mask": inputs_dict["feature_attention_mask"],
|
||||
"pixel_values": inputs_dict["pixel_values"],
|
||||
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||
}
|
||||
|
||||
# add position_ids
|
||||
vision_position_ids, deltas = model.get_rope_index(
|
||||
input_ids=inputs_dict["input_ids"],
|
||||
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
audio_seqlens=torch.sum(inputs_dict["feature_attention_mask"], dim=1),
|
||||
) # [3, bs, padded-seq-len]
|
||||
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||
3, -1
|
||||
) # [3, bs*padfree-len]
|
||||
text_padfree_positions = torch.cat(
|
||||
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||
) # [1, bs*padfree-len]
|
||||
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||
:, None, :
|
||||
]
|
||||
|
||||
if fa_kwargs:
|
||||
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
padfree_inputs_dict.update(
|
||||
{
|
||||
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"max_length_q": max_length,
|
||||
"max_length_k": max_length,
|
||||
}
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@unittest.skip("Cannot do contrastive generation, has custom `generate()`")
|
||||
def test_contrastive_generate(self):
|
||||
pass
|
||||
|
||||
@@ -325,6 +325,89 @@ class Qwen2_5_VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
)
|
||||
self.assertIsNotNone(outputs)
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
"pixel_values": inputs_dict["pixel_values"],
|
||||
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||
}
|
||||
|
||||
# add position_ids
|
||||
vision_position_ids, deltas = model.model.get_rope_index(
|
||||
input_ids=inputs_dict["input_ids"],
|
||||
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
) # [3, bs, padded-seq-len]
|
||||
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||
3, -1
|
||||
) # [3, bs*padfree-len]
|
||||
text_padfree_positions = torch.cat(
|
||||
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||
) # [1, bs*padfree-len]
|
||||
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||
:, None, :
|
||||
]
|
||||
|
||||
if fa_kwargs:
|
||||
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
padfree_inputs_dict.update(
|
||||
{
|
||||
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"max_length_q": max_length,
|
||||
"max_length_k": max_length,
|
||||
}
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
@@ -168,6 +169,7 @@ class Qwen2VLVisionText2TextModelTester:
|
||||
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
|
||||
|
||||
input_ids[:, -1] = self.pad_token_id
|
||||
attention_mask[:, -1] = 0
|
||||
input_ids[input_ids == self.video_token_id] = self.pad_token_id
|
||||
input_ids[input_ids == self.image_token_id] = self.pad_token_id
|
||||
input_ids[input_ids == self.vision_start_token_id] = self.pad_token_id
|
||||
@@ -281,6 +283,90 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
||||
generation_output.logits[0], forward_output.logits[:, -1, :], rtol=1e-4, atol=1e-4
|
||||
)
|
||||
|
||||
def flash_attention_padding_matches_padding_free_with_position_ids(
|
||||
self, attn_implementation: str, fa_kwargs: bool = False
|
||||
):
|
||||
max_new_tokens = 30
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.float16]:
|
||||
dummy_input = dummy_input.to(torch.bfloat16)
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
if 0 in inputs_dict["attention_mask"][:, -1]:
|
||||
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
|
||||
dummy_attention_mask = inputs_dict["attention_mask"]
|
||||
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
|
||||
|
||||
model = (
|
||||
model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
# flatten
|
||||
padfree_inputs_dict = {
|
||||
"pixel_values": inputs_dict["pixel_values"],
|
||||
"image_grid_thw": inputs_dict["image_grid_thw"],
|
||||
"input_ids": inputs_dict["input_ids"][dummy_attention_mask.bool()].unsqueeze(0),
|
||||
}
|
||||
|
||||
# add position_ids
|
||||
vision_position_ids, deltas = model.model.get_rope_index(
|
||||
input_ids=inputs_dict["input_ids"],
|
||||
image_grid_thw=inputs_dict["image_grid_thw"],
|
||||
attention_mask=inputs_dict["attention_mask"],
|
||||
) # [3, bs, padded-seq-len]
|
||||
vision_padfree_positions = vision_position_ids[:, dummy_attention_mask.bool()].view(
|
||||
3, -1
|
||||
) # [3, bs*padfree-len]
|
||||
text_padfree_positions = torch.cat(
|
||||
[torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]
|
||||
) # [1, bs*padfree-len]
|
||||
text_padfree_positions = text_padfree_positions.long().unsqueeze(0).to(torch_device)
|
||||
padfree_inputs_dict["position_ids"] = torch.cat([text_padfree_positions, vision_padfree_positions])[
|
||||
:, None, :
|
||||
]
|
||||
|
||||
if fa_kwargs:
|
||||
cu_seq_lens = [0] + dummy_attention_mask.sum(1).tolist()
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, device=torch_device)
|
||||
max_length = cu_seq_lens.diff().max().item()
|
||||
padfree_inputs_dict.update(
|
||||
{
|
||||
"cu_seq_lens_q": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"cu_seq_lens_k": cu_seq_lens.cumsum(-1).to(dtype=torch.int32),
|
||||
"max_length_q": max_length,
|
||||
"max_length_k": max_length,
|
||||
}
|
||||
)
|
||||
|
||||
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
|
||||
torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
|
||||
# acceptable numerical instability
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
@unittest.skip(reason="Feedforward chunking is not yet supported")
|
||||
def test_feed_forward_chunking(self):
|
||||
pass
|
||||
|
||||
@@ -4129,13 +4129,14 @@ class ModelTesterMixin:
|
||||
self.skipTest(reason="Model architecture does not support attentions")
|
||||
|
||||
max_new_tokens = 30
|
||||
support_flag = {
|
||||
"sdpa": "_supports_sdpa",
|
||||
"flash_attention_2": "_supports_flash_attn",
|
||||
"flash_attention_3": "_supports_flash_attn",
|
||||
}
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not (
|
||||
model_class._supports_flash_attn_2
|
||||
if attn_implementation == "flash_attention_2"
|
||||
else model_class._supports_flash_attn_3
|
||||
):
|
||||
if not getattr(model_class, support_flag[attn_implementation]):
|
||||
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -4204,8 +4205,9 @@ class ModelTesterMixin:
|
||||
.to(torch_device)
|
||||
)
|
||||
|
||||
res_padded = model(**inputs_dict)
|
||||
res_padfree = model(**padfree_inputs_dict)
|
||||
# We need to do simple forward without cache in roder to trigger packed SDPA/FLEX/EAGER path
|
||||
res_padded = model(**inputs_dict, use_cache=False)
|
||||
res_padfree = model(**padfree_inputs_dict, use_cache=False)
|
||||
|
||||
logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
|
||||
logits_padfree = res_padfree.logits[0]
|
||||
@@ -4215,6 +4217,16 @@ class ModelTesterMixin:
|
||||
tol = torch.finfo(torch.bfloat16).eps
|
||||
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)
|
||||
|
||||
# Mark slow for now as it is failing for all multimodals/non-transformer arch models and a few LLMs
|
||||
# FIXME @raushan
|
||||
@slow
|
||||
def test_eager_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="eager")
|
||||
|
||||
@slow
|
||||
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
|
||||
self.flash_attention_padding_matches_padding_free_with_position_ids(attn_implementation="sdpa")
|
||||
|
||||
@require_flash_attn
|
||||
@require_torch_gpu
|
||||
@mark.flash_attn_test
|
||||
|
||||
Reference in New Issue
Block a user