fix: Qwen2-VL generate with inputs_embeds (#35466)
* fix: Qwen2-VL generate with inputs_embeds
* change: make input_ids optional in get_rope_index
This commit is contained in:
@@ -32,13 +32,8 @@ from torch.nn import CrossEntropyLoss, LayerNorm
|
|||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import (
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
AttentionMaskConverter,
|
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
|
||||||
)
|
|
||||||
from ...modeling_outputs import (
|
|
||||||
BaseModelOutputWithPast,
|
|
||||||
ModelOutput,
|
|
||||||
)
|
|
||||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -1420,7 +1415,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||||||
|
|
||||||
def get_rope_index(
|
def get_rope_index(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.LongTensor,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
@@ -1550,7 +1545,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||||
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
|
||||||
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
|
||||||
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
|
||||||
else:
|
else:
|
||||||
@@ -1676,7 +1671,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin):
|
|||||||
attention_mask = attention_mask.to(inputs_embeds.device)
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||||
if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
# calculate RoPE index once per generation in the pre-fill stage only
|
# calculate RoPE index once per generation in the pre-fill stage only
|
||||||
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
|
||||||
position_ids, rope_deltas = self.get_rope_index(
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
|
|||||||
@@ -1615,9 +1615,7 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
# There are a few exception patterns in this test:
|
# There are a few exception patterns in this test:
|
||||||
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
|
# 1 - Some models can't generate without `input_ids`, when `inputs_embeds` are passed
|
||||||
requires_inputs_ids = any(
|
requires_inputs_ids = any(model_name in model_class.__name__.lower() for model_name in ["idefics"])
|
||||||
model_name in model_class.__name__.lower() for model_name in ["idefics", "qwen2vl"]
|
|
||||||
)
|
|
||||||
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
# 2 - Complex `inputs_embeds` computation, i.e. the correct computation of inputs embeds is more complex
|
||||||
# than calling the embedding layer with `input_ids`. Subcases of this exception:
|
# than calling the embedding layer with `input_ids`. Subcases of this exception:
|
||||||
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
# 2.A - Ignore `scale_embedding`, if the model supports it (it is controlled by a model-dependent flag)
|
||||||
|
|||||||
Reference in New Issue
Block a user