From d3b8627b56caa7ca8fac113c9f28d0256db0194d Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 1 Aug 2025 10:01:06 +0200 Subject: [PATCH] [VLMs] split out "get placeholder mask" to helper (#39777) * batch upidate all models * update * forgot about llava onevision * update * fix tests * delete file * typo * fix emu3 once and forever * update cohere2 vision as well --- src/transformers/models/aria/modeling_aria.py | 44 +++++---- src/transformers/models/aria/modular_aria.py | 20 +--- .../models/aya_vision/modeling_aya_vision.py | 46 ++++++---- .../models/aya_vision/modular_aya_vision.py | 22 +---- .../models/blip_2/modeling_blip_2.py | 50 ++++++---- .../models/chameleon/modeling_chameleon.py | 44 +++++---- .../cohere2_vision/modeling_cohere2_vision.py | 37 ++++++-- .../cohere2_vision/modular_cohere2_vision.py | 13 +-- .../deepseek_vl/modeling_deepseek_vl.py | 36 ++++++-- .../modeling_deepseek_vl_hybrid.py | 24 +++++ src/transformers/models/emu3/modeling_emu3.py | 37 ++++++-- src/transformers/models/emu3/modular_emu3.py | 37 ++++++-- src/transformers/models/fuyu/modeling_fuyu.py | 40 +++++--- .../models/gemma3/modeling_gemma3.py | 54 +++++------ .../models/gemma3/modular_gemma3.py | 23 +---- .../models/gemma3n/modeling_gemma3n.py | 89 ++++++++++-------- .../models/gemma3n/modular_gemma3n.py | 82 ++++++++++------- .../models/glm4v/modeling_glm4v.py | 82 +++++++++-------- .../models/glm4v/modular_glm4v.py | 42 +-------- .../models/got_ocr2/modeling_got_ocr2.py | 43 +++++---- .../models/got_ocr2/modular_got_ocr2.py | 19 +--- .../instructblip/modeling_instructblip.py | 62 +++++++------ .../modeling_instructblipvideo.py | 50 ++++++---- .../modular_instructblipvideo.py | 35 ++++--- .../models/internvl/modeling_internvl.py | 53 ++++++----- .../models/internvl/modular_internvl.py | 22 +---- .../models/janus/modeling_janus.py | 36 ++++++-- .../models/janus/modular_janus.py | 36 ++++++-- .../models/llama4/modeling_llama4.py | 48 ++++++---- .../models/llava/modeling_llava.py | 49 ++++++---- .../models/llava_next/modeling_llava_next.py | 48 ++++++---- .../modeling_llava_next_video.py | 92 +++++++++++-------- .../modular_llava_next_video.py | 92 +++++++++++-------- .../modeling_llava_onevision.py | 84 +++++++++-------- .../modular_llava_onevision.py | 44 ++------- .../models/mistral3/modeling_mistral3.py | 49 ++++++---- .../models/mistral3/modular_mistral3.py | 25 +---- .../models/paligemma/modeling_paligemma.py | 46 ++++++---- .../perception_lm/modeling_perception_lm.py | 68 ++++++++++---- .../perception_lm/modular_perception_lm.py | 64 +++++++++---- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 82 +++++++++++------ .../qwen2_5_omni/modular_qwen2_5_omni.py | 82 +++++++++++------ .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 85 +++++++++-------- .../models/qwen2_5_vl/modular_qwen2_5_vl.py | 45 ++------- .../models/qwen2_vl/modeling_qwen2_vl.py | 86 +++++++++-------- .../video_llava/modeling_video_llava.py | 83 ++++++++++------- .../models/vipllava/modeling_vipllava.py | 46 ++++++---- .../models/vipllava/modular_vipllava.py | 22 +---- .../deepseek_vl/test_modeling_deepseek_vl.py | 5 +- tests/models/emu3/test_modeling_emu3.py | 11 ++- tests/models/janus/test_modeling_janus.py | 2 +- .../test_modeling_llava_onevision.py | 3 + 52 files changed, 1370 insertions(+), 1069 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c1278bee03..deda70ec55 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -978,6 +978,30 @@ class AriaModel(AriaPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature, attn_mask=image_attn_mask) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1007,29 +1031,15 @@ class AriaModel(AriaPreTrainedModel): # 2. Merge text and images if pixel_values is not None and inputs_embeds.shape[1] != 1: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_features = self.get_image_features( pixel_values=pixel_values, pixel_mask=pixel_mask, vision_feature_layer=self.config.vision_feature_layer, ) - n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] - n_image_features = n_images * n_features_per_image - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self._get_image_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 4a19c3387d..d126c601a5 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1431,29 +1431,15 @@ class AriaModel(LlavaModel): # 2. Merge text and images if pixel_values is not None and inputs_embeds.shape[1] != 1: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_features = self.get_image_features( pixel_values=pixel_values, pixel_mask=pixel_mask, vision_feature_layer=self.config.vision_feature_layer, ) - n_images, n_features_per_image = image_features.shape[0], image_features.shape[1] - n_image_features = n_images * n_features_per_image - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self._get_image_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/aya_vision/modeling_aya_vision.py b/src/transformers/models/aya_vision/modeling_aya_vision.py index dd922c63ce..e899c1ebef 100644 --- a/src/transformers/models/aya_vision/modeling_aya_vision.py +++ b/src/transformers/models/aya_vision/modeling_aya_vision.py @@ -32,7 +32,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.generic import check_model_inputs from ..auto import AutoModel from .configuration_aya_vision import AyaVisionConfig @@ -242,6 +242,30 @@ class AyaVisionModel(AyaVisionPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @check_model_inputs @auto_docstring def forward( @@ -279,24 +303,10 @@ class AyaVisionModel(AyaVisionPreTrainedModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/aya_vision/modular_aya_vision.py b/src/transformers/models/aya_vision/modular_aya_vision.py index 5a7b0950cd..4d18b5806c 100644 --- a/src/transformers/models/aya_vision/modular_aya_vision.py +++ b/src/transformers/models/aya_vision/modular_aya_vision.py @@ -32,7 +32,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torchdynamo_compiling, logging +from ...utils import auto_docstring, logging from ...utils.generic import check_model_inputs from .configuration_aya_vision import AyaVisionConfig @@ -200,24 +200,10 @@ class AyaVisionModel(LlavaModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum(dim=1).sum(dim=0) - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 66c7c95e6b..169d35081f 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -1455,6 +1455,21 @@ class Blip2Model(Blip2PreTrainedModel): return query_outputs + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + @auto_docstring def forward( self, @@ -1545,16 +1560,8 @@ class Blip2Model(Blip2PreTrainedModel): if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( special_image_mask, language_model_inputs ) @@ -1938,6 +1945,21 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): return language_model_inputs, vision_outputs, query_outputs return language_model_inputs + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + @auto_docstring def forward( self, @@ -2042,16 +2064,8 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter( special_image_mask, language_model_inputs ) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index dd0d0a5582..6a3f8db749 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -35,7 +35,6 @@ from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, - is_torchdynamo_compiling, logging, ) from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig @@ -888,6 +887,30 @@ class ChameleonModel(ChameleonPreTrainedModel): vision_embeddings = self.get_input_embeddings()(image_tokens) return vision_embeddings + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @auto_docstring def forward( self, @@ -924,23 +947,10 @@ class ChameleonModel(ChameleonPreTrainedModel): inputs_embeds = self.embed_tokens(input_ids) if pixel_values is not None: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - - n_image_tokens_in_text = special_image_mask.sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - image_embeds = self.get_image_features(pixel_values) - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_embeds.numel(): - n_image_features = image_embeds.shape[0] * image_embeds.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" - ) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # torch.jit.trace() doesn't support cache objects in the output diff --git a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py index 91f84cdd6b..36a07c3410 100644 --- a/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/modeling_cohere2_vision.py @@ -197,6 +197,30 @@ class Cohere2VisionModel(Cohere2VisionPreTrainedModel): image_features = self.multi_modal_projector(selected_image_feature) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @check_model_inputs @auto_docstring def forward( @@ -225,16 +249,9 @@ class Cohere2VisionModel(Cohere2VisionPreTrainedModel): if pixel_values is not None: image_features = self.get_image_features(pixel_values, image_num_patches=image_num_patches) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/cohere2_vision/modular_cohere2_vision.py b/src/transformers/models/cohere2_vision/modular_cohere2_vision.py index 90cf7defe7..c141c02c7c 100644 --- a/src/transformers/models/cohere2_vision/modular_cohere2_vision.py +++ b/src/transformers/models/cohere2_vision/modular_cohere2_vision.py @@ -145,16 +145,9 @@ class Cohere2VisionModel(AyaVisionModel): if pixel_values is not None: image_features = self.get_image_features(pixel_values, image_num_patches=image_num_patches) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py index 60ca3394fa..c113d41e75 100644 --- a/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py +++ b/src/transformers/models/deepseek_vl/modeling_deepseek_vl.py @@ -177,6 +177,30 @@ class DeepseekVLModel(DeepseekVLPreTrainedModel): image_embeds = self.aligner(image_embeds.last_hidden_state) return image_embeds + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -200,18 +224,12 @@ class DeepseekVLModel(DeepseekVLPreTrainedModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - if input_ids is None: - image_attention_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_attention_mask = image_attention_mask.all(-1) - else: - image_attention_mask = input_ids == self.config.image_token_id - - image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_attention_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) lm_output = self.language_model( diff --git a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py index 1910c5659c..23ce4c061a 100644 --- a/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +++ b/src/transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py @@ -285,6 +285,30 @@ class DeepseekVLHybridModel(DeepseekVLHybridPreTrainedModel): images_embeds = self.aligner(vision_encodings, high_res_vision_encodings) return images_embeds + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring(custom_args=DEEPSEEK_VL_COMMON_CUSTOM_ARGS) def forward( diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 182afe6b90..b5ee232439 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1383,6 +1383,30 @@ class Emu3Model(Emu3PreTrainedModel): image = self.vqmodel.decode(image_tokens) return image + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1415,16 +1439,9 @@ class Emu3Model(Emu3PreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_sizes) image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 7bd59d1fae..e8e66fc853 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -965,6 +965,30 @@ class Emu3Model(Emu3PreTrainedModel): image = self.vqmodel.decode(image_tokens) return image + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.vocabulary_mapping.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -997,16 +1021,9 @@ class Emu3Model(Emu3PreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_sizes) image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.vocabulary_mapping.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index ff163aacd4..409333a8c6 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -147,6 +147,30 @@ class FuyuModel(FuyuPreTrainedModel): ] return patch_embeddings + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @auto_docstring def forward( self, @@ -200,18 +224,10 @@ class FuyuModel(FuyuPreTrainedModel): if image_patches is not None: patch_embeddings = self.get_image_features(image_patches) - patch_embeddings = torch.cat(patch_embeddings, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - patch_embeddings = patch_embeddings.to(inputs_embeds.device, inputs_embeds.dtype) + patch_embeddings = torch.cat(patch_embeddings, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_tokens( + input_ids, inputs_embeds=inputs_embeds, image_features=patch_embeddings + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings) outputs = self.language_model( diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 76eb0b697b..92fa45a663 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -38,14 +38,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import check_model_inputs from ..auto import AutoModel from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig @@ -805,6 +798,30 @@ class Gemma3Model(Gemma3PreTrainedModel): image_features = self.multi_modal_projector(vision_outputs) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -880,25 +897,10 @@ class Gemma3Model(Gemma3PreTrainedModel): # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # It may already have been prepared by e.g. `generate` diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index c2ad52f108..0ca7505c56 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -31,7 +31,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutpu from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( Gemma2Attention, @@ -805,25 +805,10 @@ class Gemma3Model(PaliGemmaModel): # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # It may already have been prepared by e.g. `generate` diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 3cf07819be..3430c45fb0 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -39,14 +39,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig @@ -1968,6 +1961,48 @@ class Gemma3nModel(Gemma3nPreTrainedModel): vision_outputs *= self.config.vision_config.hidden_size**0.5 return self.embed_vision(inputs_embeds=vision_outputs) + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + audio_features: torch.FloatTensor, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_audio_mask = input_ids == self.config.audio_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}" + ) + + n_audio_tokens = special_audio_mask.sum() + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if audio_features is not None and inputs_embeds[special_audio_mask].numel() != audio_features.numel(): + raise ValueError( + f"Audio features and image tokens do not match: tokens: {n_audio_tokens}, features {audio_features.shape[0] * audio_features.shape[1]}" + ) + + return special_image_mask, special_audio_mask + @can_return_tuple def forward( self, @@ -2054,23 +2089,10 @@ class Gemma3nModel(Gemma3nPreTrainedModel): # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text and " - f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings." - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Merge text and audio @@ -2091,23 +2113,10 @@ class Gemma3nModel(Gemma3nPreTrainedModel): extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) audio_features = torch.cat((audio_features, extra_padding_features), dim=1) - - if input_ids is None: - special_audio_mask = inputs_embeds == self.embed_audio( - input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel(): - audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of audio input features does not match number of special audio tokens in the input text. " - f"Got {audio_tokens_in_text} audio tokens in the text and " - f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings." - ) audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) outputs = self.language_model( diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index 2716e9bdfe..fd40253553 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -31,7 +31,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( @@ -2259,6 +2259,48 @@ class Gemma3nModel(PaliGemmaModel): vision_outputs *= self.config.vision_config.hidden_size**0.5 return self.embed_vision(inputs_embeds=vision_outputs) + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor, + audio_features: torch.FloatTensor, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_audio_mask = input_ids == self.config.audio_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}" + ) + + n_audio_tokens = special_audio_mask.sum() + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if audio_features is not None and inputs_embeds[special_audio_mask].numel() != audio_features.numel(): + raise ValueError( + f"Audio features and image tokens do not match: tokens: {n_audio_tokens}, features {audio_features.shape[0] * audio_features.shape[1]}" + ) + + return special_image_mask, special_audio_mask + @can_return_tuple def forward( self, @@ -2345,23 +2387,10 @@ class Gemma3nModel(PaliGemmaModel): # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text and " - f"{image_features.shape[0] * image_features.shape[1]} tokens from image embeddings." - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Merge text and audio @@ -2382,23 +2411,10 @@ class Gemma3nModel(PaliGemmaModel): extra_padding_features = audio_padding_embs.expand(audio_batch_size, extra_padding_tokens, audio_embed_dim) audio_features = torch.cat((audio_features, extra_padding_features), dim=1) - - if input_ids is None: - special_audio_mask = inputs_embeds == self.embed_audio( - input_ids=torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - else: - special_audio_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_audio_mask].numel() != audio_features.numel(): - audio_tokens_in_text = (special_audio_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of audio input features does not match number of special audio tokens in the input text. " - f"Got {audio_tokens_in_text} audio tokens in the text and " - f"{audio_features.shape[0] * audio_features.shape[1]} tokens from audio embeddings." - ) audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) outputs = self.language_model( diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 9e0e4d4841..1eb28d3201 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -1182,6 +1182,46 @@ class Glm4vModel(Glm4vPreTrainedModel): image_embeds = torch.split(image_embeds, split_sizes) return image_embeds + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + @auto_docstring @can_return_tuple def forward( @@ -1224,48 +1264,14 @@ class Glm4vModel(Glm4vPreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - n_image_tokens = image_mask.sum() - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.image_token_id - - n_video_tokens = video_mask.sum() - n_video_features = video_embeds.shape[0] - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index d839328f5a..df32beb5da 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -1228,48 +1228,14 @@ class Glm4vModel(Qwen2_5_VLModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - n_image_tokens = image_mask.sum() - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.image_token_id - - n_video_tokens = video_mask.sum() - n_video_features = video_embeds.shape[0] - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 0d55eb3abc..d9f057b737 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -573,6 +573,30 @@ class GotOcr2Model(GotOcr2PreTrainedModel): image_outputs = self.vision_tower(pixel_values).last_hidden_state return self.multi_modal_projector(image_outputs) + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -603,24 +627,11 @@ class GotOcr2Model(GotOcr2PreTrainedModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/got_ocr2/modular_got_ocr2.py b/src/transformers/models/got_ocr2/modular_got_ocr2.py index 9e017659e4..f1ec914bf4 100644 --- a/src/transformers/models/got_ocr2/modular_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modular_got_ocr2.py @@ -349,24 +349,11 @@ class GotOcr2Model(LlavaModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - image_features = self.get_image_features(pixel_values=pixel_values.to(inputs_embeds.dtype)) - n_image_features = image_features.shape[0] * image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index 72482558d7..49943140fe 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -1195,6 +1195,21 @@ class InstructBlipModel(InstructBlipPreTrainedModel): if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1278,21 +1293,15 @@ class InstructBlipModel(InstructBlipPreTrainedModel): ) query_output = query_outputs[0][:, : query_tokens.size(1), :] - # step 3: use the language model, conditioned on the query outputs and the prompt - language_model_inputs = self.language_projection(query_output) if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) - special_image_mask = input_ids == self.config.image_token_id if attention_mask is None: attention_mask = torch.ones_like(input_ids) - else: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + # step 3: use the language model, conditioned on the query outputs and the prompt + language_model_inputs = self.language_projection(query_output) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: @@ -1461,6 +1470,21 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati return language_model_inputs, vision_outputs, query_outputs return language_model_inputs + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1567,16 +1591,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: @@ -1678,16 +1694,8 @@ class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, Generati if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 5d3aeec452..adf9a9ced6 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -1191,6 +1191,21 @@ class InstructBlipVideoModel(InstructBlipVideoPreTrainedModel): if hasattr(self.language_model, "_hf_hook"): self.language_model._hf_hook.io_same_device = True # For `generate` compatibility + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1433,6 +1448,21 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel """ pass + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1534,16 +1564,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: @@ -1645,16 +1667,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} diff --git a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py index e7bcfaba82..9588061e7e 100644 --- a/src/transformers/models/instructblipvideo/modular_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modular_instructblipvideo.py @@ -372,6 +372,21 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera ): pass + def get_placeholder_mask(self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.video_token_id + + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask + def forward( self, pixel_values: torch.FloatTensor, @@ -471,16 +486,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) if self.config.use_decoder_only_language_model: @@ -582,16 +589,8 @@ class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGenera if attention_mask is None: attention_mask = torch.ones_like(input_ids) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs) inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask} diff --git a/src/transformers/models/internvl/modeling_internvl.py b/src/transformers/models/internvl/modeling_internvl.py index 8e3963cfab..ab0b30d839 100644 --- a/src/transformers/models/internvl/modeling_internvl.py +++ b/src/transformers/models/internvl/modeling_internvl.py @@ -36,14 +36,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - ModelOutput, - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - torch_int, -) +from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_int from ..auto import AutoModel from .configuration_internvl import InternVLConfig, InternVLVisionConfig @@ -638,6 +631,30 @@ class InternVLModel(InternVLPreTrainedModel): vision_features = self.multi_modal_projector(vision_features) return vision_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -683,24 +700,10 @@ class InternVLModel(InternVLPreTrainedModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/internvl/modular_internvl.py b/src/transformers/models/internvl/modular_internvl.py index 67d864eec7..28ce0b06d9 100644 --- a/src/transformers/models/internvl/modular_internvl.py +++ b/src/transformers/models/internvl/modular_internvl.py @@ -29,7 +29,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging, torch_int +from ...utils import auto_docstring, can_return_tuple, logging, torch_int from ..clip.modeling_clip import CLIPMLP from ..janus.modeling_janus import JanusVisionAttention from ..llama.modeling_llama import LlamaRMSNorm @@ -616,24 +616,10 @@ class InternVLModel(LlavaModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/janus/modeling_janus.py b/src/transformers/models/janus/modeling_janus.py index ebdc2f23ea..794514867a 100644 --- a/src/transformers/models/janus/modeling_janus.py +++ b/src/transformers/models/janus/modeling_janus.py @@ -1046,6 +1046,30 @@ class JanusModel(JanusPreTrainedModel): image_embeds = self.aligner(image_embeds.last_hidden_state) return image_embeds + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -1069,18 +1093,12 @@ class JanusModel(JanusPreTrainedModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - if input_ids is None: - image_attention_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_attention_mask = image_attention_mask.all(-1) - else: - image_attention_mask = input_ids == self.config.image_token_id - - image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_attention_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) lm_output = self.language_model( diff --git a/src/transformers/models/janus/modular_janus.py b/src/transformers/models/janus/modular_janus.py index 11b0848620..ad4ffd196e 100644 --- a/src/transformers/models/janus/modular_janus.py +++ b/src/transformers/models/janus/modular_janus.py @@ -906,6 +906,30 @@ class JanusModel(JanusPreTrainedModel): image_embeds = self.aligner(image_embeds.last_hidden_state) return image_embeds + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + n_image_features = image_features.shape[0] * image_features.shape[1] + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -929,18 +953,12 @@ class JanusModel(JanusPreTrainedModel): inputs_embeds = self.get_input_embeddings()(input_ids) if pixel_values is not None: - if input_ids is None: - image_attention_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_attention_mask = image_attention_mask.all(-1) - else: - image_attention_mask = input_ids == self.config.image_token_id - - image_attention_mask = image_attention_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = self.get_image_features(pixel_values) image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1]) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_attention_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features) lm_output = self.language_model( diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index d04d443ec8..9d70d8eca0 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -1201,6 +1201,29 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): hidden_state = image_outputs.last_hidden_state return hidden_state + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + return special_image_mask + @auto_docstring def forward( self, @@ -1286,25 +1309,12 @@ class Llama4ForConditionalGeneration(Llama4PreTrainedModel, GenerationMixin): ) vision_flat = image_features.view(-1, image_features.size(-1)) - projected_vision_flat = self.multi_modal_projector(vision_flat) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if n_image_tokens != projected_vision_flat.size(0): - raise ValueError( - f"Mismatch: final_mask wants {n_image_tokens} embeddings, " - f"but multi_modal_projector returned {projected_vision_flat.size(0)}" - ) - projected_vision_flat = projected_vision_flat.to(inputs_embeds.device, inputs_embeds.dtype) + projected_vision_flat = self.multi_modal_projector(vision_flat).to( + inputs_embeds.device, inputs_embeds.dtype + ) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=projected_vision_flat + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, projected_vision_flat) outputs = self.language_model( diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index e5145554dc..b2fec92999 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -28,7 +28,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_llava import LlavaConfig @@ -218,6 +218,30 @@ class LlavaModel(LlavaPreTrainedModel): image_features = list(image_features) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -265,25 +289,10 @@ class LlavaModel(LlavaPreTrainedModel): vision_feature_select_strategy=vision_feature_select_strategy, image_sizes=image_sizes, ) - image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 94f03925b8..241acaa398 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -30,7 +30,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_llava_next import LlavaNextConfig @@ -426,6 +426,29 @@ class LlavaNextModel(LlavaNextPreTrainedModel): ) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -479,25 +502,10 @@ class LlavaNextModel(LlavaNextPreTrainedModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index dce37a4dd9..721da01a9b 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -35,7 +35,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_llava_next_video import LlavaNextVideoConfig @@ -476,6 +476,46 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): ) return image_features + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + @can_return_tuple @auto_docstring def forward( @@ -523,35 +563,20 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - if pixel_values is not None and pixel_values.size(0) > 0: + if pixel_values is not None: image_features = self.get_image_features( pixel_values, image_sizes, vision_feature_layer=self.vision_feature_layer, vision_feature_select_strategy=self.vision_feature_select_strategy, ) - image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: + if pixel_values_videos is not None: video_features = self.get_video_features( pixel_values_videos, vision_feature_layer=self.vision_feature_layer, @@ -561,25 +586,12 @@ class LlavaNextVideoModel(LlavaNextVideoPreTrainedModel): video_feature_lens = [feature.size(0) for feature in video_features] video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - n_video_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_features = video_features.shape[0] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index fecd2320f9..15f3ad9141 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -33,7 +33,7 @@ from ...cache_utils import Cache from ...configuration_utils import PretrainedConfig from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack -from ...utils import is_torchdynamo_compiling, logging +from ...utils import logging from ..auto import CONFIG_MAPPING, AutoConfig @@ -397,6 +397,46 @@ class LlavaNextVideoModel(LlavaNextModel): video_features = torch.split(video_features, frames, dim=0) return video_features + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + def forward( self, input_ids: torch.LongTensor = None, @@ -436,35 +476,20 @@ class LlavaNextVideoModel(LlavaNextModel): if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) - if pixel_values is not None and pixel_values.size(0) > 0: + if pixel_values is not None: image_features = self.get_image_features( pixel_values, image_sizes, vision_feature_layer=self.vision_feature_layer, vision_feature_select_strategy=self.vision_feature_select_strategy, ) - image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: + if pixel_values_videos is not None: video_features = self.get_video_features( pixel_values_videos, vision_feature_layer=self.vision_feature_layer, @@ -474,25 +499,12 @@ class LlavaNextVideoModel(LlavaNextModel): video_feature_lens = [feature.size(0) for feature in video_features] video_features = torch.cat(video_features, dim=0) video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - n_video_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_features = video_features.shape[0] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 41c39d26f3..d63cde2b57 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -39,7 +39,6 @@ from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, - is_torchdynamo_compiling, logging, ) from ..auto import AutoModel @@ -491,6 +490,46 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): ) return image_features + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + @can_return_tuple @auto_docstring def forward( @@ -557,24 +596,10 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): batch_num_images=batch_num_images, ) image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Video are simply embedded and further pooled to decrease seq len @@ -588,25 +613,10 @@ class LlavaOnevisionModel(LlavaOnevisionPreTrainedModel): self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device) ) video_features = torch.cat((video_features, image_newline), dim=1) - video_features = video_features.flatten(0, 1) - - if input_ids is None: - special_video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_video_mask = special_video_mask.all(-1) - else: - special_video_mask = input_ids == self.config.video_token_id - - n_video_tokens = (special_video_mask).sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel(): - n_video_features = video_features.shape[0] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + video_features = video_features.flatten(0, 1).to(inputs_embeds.device, inputs_embeds.dtype) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( diff --git a/src/transformers/models/llava_onevision/modular_llava_onevision.py b/src/transformers/models/llava_onevision/modular_llava_onevision.py index 233a96110b..14a8f39915 100644 --- a/src/transformers/models/llava_onevision/modular_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modular_llava_onevision.py @@ -50,7 +50,6 @@ from ...utils import ( TensorType, auto_docstring, can_return_tuple, - is_torchdynamo_compiling, is_torchvision_available, is_torchvision_v2_available, logging, @@ -541,24 +540,10 @@ class LlavaOnevisionModel(LlavaNextVideoModel): batch_num_images=batch_num_images, ) image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Video are simply embedded and further pooled to decrease seq len @@ -572,25 +557,10 @@ class LlavaOnevisionModel(LlavaNextVideoModel): self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device) ) video_features = torch.cat((video_features, image_newline), dim=1) - video_features = video_features.flatten(0, 1) - - if input_ids is None: - special_video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_video_mask = special_video_mask.all(-1) - else: - special_video_mask = input_ids == self.config.video_token_id - - n_video_tokens = (special_video_mask).sum() - special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_video_mask].numel() != video_features.numel(): - n_video_features = video_features.shape[0] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + video_features = video_features.flatten(0, 1).to(inputs_embeds.device, inputs_embeds.dtype) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( diff --git a/src/transformers/models/mistral3/modeling_mistral3.py b/src/transformers/models/mistral3/modeling_mistral3.py index 7e8dabef15..2ac42c4e10 100644 --- a/src/transformers/models/mistral3/modeling_mistral3.py +++ b/src/transformers/models/mistral3/modeling_mistral3.py @@ -33,7 +33,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ..auto import AutoModel from .configuration_mistral3 import Mistral3Config @@ -262,6 +262,30 @@ class Mistral3Model(Mistral3PreTrainedModel): image_features = torch.split(image_features.squeeze(0), split_sizes) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -302,25 +326,10 @@ class Mistral3Model(Mistral3PreTrainedModel): vision_feature_layer=vision_feature_layer, image_sizes=image_sizes, ) - image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/mistral3/modular_mistral3.py b/src/transformers/models/mistral3/modular_mistral3.py index da507655fa..454904c16b 100644 --- a/src/transformers/models/mistral3/modular_mistral3.py +++ b/src/transformers/models/mistral3/modular_mistral3.py @@ -22,7 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...processing_utils import Unpack -from ...utils import is_torchdynamo_compiling, logging +from ...utils import logging from ..llava.modeling_llava import ( LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration, @@ -200,25 +200,10 @@ class Mistral3Model(LlavaModel): vision_feature_layer=vision_feature_layer, image_sizes=image_sizes, ) - image_features = torch.cat(image_features, dim=0) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 2d82dccc18..29a14e555a 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -32,7 +32,6 @@ from ...utils import ( TransformersKwargs, auto_docstring, can_return_tuple, - is_torchdynamo_compiling, logging, ) from ..auto import AutoModel @@ -251,6 +250,30 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -332,25 +355,10 @@ class PaliGemmaModel(PaliGemmaPreTrainedModel): # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) causal_mask = self._update_causal_mask( diff --git a/src/transformers/models/perception_lm/modeling_perception_lm.py b/src/transformers/models/perception_lm/modeling_perception_lm.py index cb99ca8e19..4210cd73e5 100644 --- a/src/transformers/models/perception_lm/modeling_perception_lm.py +++ b/src/transformers/models/perception_lm/modeling_perception_lm.py @@ -207,6 +207,46 @@ class PerceptionLMModel(PerceptionLMPreTrainedModel): image_features = self.multi_modal_projector(image_outputs) return image_features + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.size()[:-1].numel()}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.size()[:-1].numel()}" + ) + + return special_image_mask, special_video_mask + @can_return_tuple @auto_docstring def forward( @@ -241,24 +281,20 @@ class PerceptionLMModel(PerceptionLMPreTrainedModel): image_features = None if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values.to(inputs_embeds), + image_features = self.get_image_features(pixel_values=pixel_values) + image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - self.check_mask_feature_size_match(special_image_mask, image_features) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - image_features = image_features.to(inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) video_features = None if pixel_values_videos is not None: - video_features = self.get_image_features( - pixel_values=pixel_values_videos.to(inputs_embeds), + video_features = self.get_image_features(pixel_values=pixel_values_videos) + video_features = video_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features ) - special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - self.check_mask_feature_size_match(special_video_mask, video_features) - special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - video_features = video_features.to(inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( @@ -283,14 +319,6 @@ class PerceptionLMModel(PerceptionLMPreTrainedModel): video_hidden_states=(video_features if pixel_values_videos is not None else None), ) - def check_mask_feature_size_match(self, media_mask, media_features): - media_token_count = media_mask.sum() - media_feature_size = media_features.size()[:-1].numel() - if media_token_count != media_feature_size: - raise ValueError( - f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})" - ) - @auto_docstring class PerceptionLMForConditionalGeneration(PerceptionLMPreTrainedModel, GenerationMixin): diff --git a/src/transformers/models/perception_lm/modular_perception_lm.py b/src/transformers/models/perception_lm/modular_perception_lm.py index 3258fcd79f..bb1a03a040 100644 --- a/src/transformers/models/perception_lm/modular_perception_lm.py +++ b/src/transformers/models/perception_lm/modular_perception_lm.py @@ -168,13 +168,45 @@ class PerceptionLMModel(LlavaModel): image_features = self.multi_modal_projector(image_outputs) return image_features - def check_mask_feature_size_match(self, media_mask, media_features): - media_token_count = media_mask.sum() - media_feature_size = media_features.size()[:-1].numel() - if media_token_count != media_feature_size: - raise ValueError( - f"The number of tokens in the media mask ({media_token_count}) does not match the number of features in the media features ({media_feature_size}. Features shape: {media_features.shape})" + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.size()[:-1].numel()}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.size()[:-1].numel()}" + ) + + return special_image_mask, special_video_mask @can_return_tuple @auto_docstring @@ -210,24 +242,20 @@ class PerceptionLMModel(LlavaModel): image_features = None if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values.to(inputs_embeds), + image_features = self.get_image_features(pixel_values=pixel_values) + image_features = image_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features ) - special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - self.check_mask_feature_size_match(special_image_mask, image_features) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - image_features = image_features.to(inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) video_features = None if pixel_values_videos is not None: - video_features = self.get_image_features( - pixel_values=pixel_values_videos.to(inputs_embeds), + video_features = self.get_image_features(pixel_values=pixel_values_videos) + video_features = video_features.to(inputs_embeds.device, dtype=inputs_embeds.dtype) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features ) - special_video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1) - self.check_mask_feature_size_match(special_video_mask, video_features) - special_video_mask = special_video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - video_features = video_features.to(inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index a10a0955f8..68e3b1982f 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -1771,6 +1771,54 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo return audio_features + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + special_audio_mask = input_ids == self.config.audio_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask, special_video_mask, special_audio_mask + @auto_docstring def forward( self, @@ -1870,44 +1918,24 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo feature_attention_mask=feature_attention_mask, audio_feature_lengths=audio_feature_lengths, ) - if input_ids is None: - audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - audio_mask = audio_mask.all(-1) - else: - audio_mask = input_ids == self.config.audio_token_id - - audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.video_token_id - - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if feature_attention_mask is not None: diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index 56fdff57e5..01995f56d6 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -2216,6 +2216,54 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo return audio_features + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + special_audio_mask = ( + inputs_embeds + == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + ).all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + special_audio_mask = input_ids == self.config.audio_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + return special_image_mask, special_video_mask, special_audio_mask + @auto_docstring def forward( self, @@ -2315,44 +2363,24 @@ class Qwen2_5OmniThinkerForConditionalGeneration(Qwen2_5OmniPreTrainedModelForCo feature_attention_mask=feature_attention_mask, audio_feature_lengths=audio_feature_lengths, ) - if input_ids is None: - audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - audio_mask = audio_mask.all(-1) - else: - audio_mask = input_ids == self.config.audio_token_id - - audio_mask = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features) if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.video_token_id - - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if feature_attention_mask is not None: diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index c270e2714c..0c65b91376 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1190,6 +1190,46 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): image_embeds = torch.split(image_embeds, split_sizes) return image_embeds + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + @auto_docstring def forward( self, @@ -1233,47 +1273,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (image_mask).sum() - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.video_token_id - - n_video_tokens = (video_mask).sum() - n_video_features = video_embeds.shape[0] - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index 7c4e7117a2..b4e521ff93 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -575,47 +575,18 @@ class Qwen2_5_VLModel(Qwen2VLModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (image_mask).sum() - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - video_mask = video_mask.all(-1) - else: - video_mask = input_ids == self.config.video_token_id - - n_video_tokens = (video_mask).sum() - n_video_features = video_embeds.shape[0] - video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 068199e6d9..c49c10714a 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1127,6 +1127,46 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): image_embeds = torch.split(image_embeds, split_sizes) return image_embeds + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + @auto_docstring def forward( self, @@ -1167,48 +1207,18 @@ class Qwen2VLModel(Qwen2VLPreTrainedModel): if pixel_values is not None: image_embeds = self.get_image_features(pixel_values, image_grid_thw) - image_embeds = torch.cat(image_embeds, dim=0) - - if input_ids is None: - image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - image_mask = image_mask.all(-1) - else: - image_mask = input_ids == self.config.image_token_id - - n_image_tokens = image_mask.sum() - image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - n_image_features = image_embeds.shape[0] - if not is_torchdynamo_compiling() and n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) if pixel_values_videos is not None: video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) - video_embeds = torch.cat(video_embeds, dim=0) - - if input_ids is None: - video_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - n_video_tokens = (video_mask).sum(dim=1).sum(dim=0)[0] - else: - video_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) - video_mask = video_mask.expand_as(inputs_embeds).to(inputs_embeds.device) - n_video_tokens = (input_ids == self.config.image_token_id).sum() - - n_video_features = video_embeds.shape[0] - if not is_torchdynamo_compiling() and n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if position_ids is None: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index befa350b90..98786832c4 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -28,7 +28,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ..auto import AutoModel from .configuration_video_llava import VideoLlavaConfig @@ -282,6 +282,46 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): return video_features, num_frames + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0] * image_features.shape[1]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and image tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0] * video_features.shape[1]}" + ) + + return special_image_mask, special_video_mask + @can_return_tuple @auto_docstring def forward( @@ -334,48 +374,21 @@ class VideoLlavaModel(VideoLlavaPreTrainedModel): vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if pixel_values_videos is not None: video_features, num_frames = self.get_video_features( pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.video_token_id - - n_video_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != video_features.numel(): - n_video_features = video_features.shape[0] * video_features.shape[1] - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + _, special_video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_video_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 3c9d5cd0ee..f18e25c0bc 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -30,7 +30,7 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils import auto_docstring, can_return_tuple from ..auto import AutoModel from .configuration_vipllava import VipLlavaConfig @@ -186,6 +186,30 @@ class VipLlavaModel(VipLlavaPreTrainedModel): image_features = self.multi_modal_projector(image_features) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @auto_docstring def forward( self, @@ -227,24 +251,10 @@ class VipLlavaModel(VipLlavaPreTrainedModel): image_features = self.get_image_features( pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/src/transformers/models/vipllava/modular_vipllava.py b/src/transformers/models/vipllava/modular_vipllava.py index 74cefba7a9..7200659793 100644 --- a/src/transformers/models/vipllava/modular_vipllava.py +++ b/src/transformers/models/vipllava/modular_vipllava.py @@ -28,7 +28,7 @@ from transformers.models.llava.modeling_llava import ( from ...activations import ACT2FN from ...cache_utils import Cache -from ...utils import auto_docstring, is_torchdynamo_compiling, logging +from ...utils import auto_docstring, logging from .configuration_vipllava import VipLlavaConfig @@ -144,24 +144,10 @@ class VipLlavaModel(LlavaModel): image_features = self.get_image_features( pixel_values=pixel_values, vision_feature_layers=vision_feature_layers ) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - n_image_tokens = (special_image_mask).sum() - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( diff --git a/tests/models/deepseek_vl/test_modeling_deepseek_vl.py b/tests/models/deepseek_vl/test_modeling_deepseek_vl.py index bff23e9dd5..b891612055 100644 --- a/tests/models/deepseek_vl/test_modeling_deepseek_vl.py +++ b/tests/models/deepseek_vl/test_modeling_deepseek_vl.py @@ -89,9 +89,9 @@ class DeepseekVLModelTester: self.hidden_size = text_config["hidden_size"] self.num_attention_heads = text_config["num_attention_heads"] self.image_size = vision_config["image_size"] - self.num_image_tokens = vision_config["image_size"] // vision_config["patch_size"] + self.num_image_tokens = 16 self.pad_token_id = text_config["pad_token_id"] - self.image_token_id = self.vocab_size - 1 + self.image_token_id = 0 def get_config(self): return DeepseekVLConfig( @@ -115,6 +115,7 @@ class DeepseekVLModelTester: ] ) # fill image_tokens + input_ids[input_ids == self.num_image_tokens] = config.text_config.pad_token_id input_ids[:, : self.num_image_tokens] = self.image_token_id return config, input_ids, attention_mask, pixel_values diff --git a/tests/models/emu3/test_modeling_emu3.py b/tests/models/emu3/test_modeling_emu3.py index 6c4f718590..8975cfe4a0 100644 --- a/tests/models/emu3/test_modeling_emu3.py +++ b/tests/models/emu3/test_modeling_emu3.py @@ -198,12 +198,12 @@ class Emu3Vision2TextModelTester: bos_token_id=1, eos_token_id=2, image_token_id=3, - image_size=30, + image_size=15, codebook_size=20, temporal_downsample_factor=1, base_channels=32, - vq_channel_multiplier=[1, 1], - image_seq_length=100, + vq_channel_multiplier=[1, 2, 1], + image_seq_length=12, vq_img_token_start_id=3, ): self.parent = parent @@ -288,6 +288,7 @@ class Emu3Vision2TextModelTester: "base_channels": self.base_channels, "channel_multiplier": self.vq_channel_multiplier, "hidden_size": self.base_channels, + "attn_resolutions": [], } return Emu3Config(text_config=text_config, vq_config=vq_config, vocabulary_map=vocab_map) @@ -358,6 +359,10 @@ class Emu3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline def test_generate_with_static_cache(self): pass + # @unittest.skip("Emu3 can't be smaller than currently if we want to downsample images") + # def test_model_is_small(self): + # pass + @require_torch class Emu3IntegrationTest(unittest.TestCase): diff --git a/tests/models/janus/test_modeling_janus.py b/tests/models/janus/test_modeling_janus.py index 254da46ac8..2fa257bc5c 100644 --- a/tests/models/janus/test_modeling_janus.py +++ b/tests/models/janus/test_modeling_janus.py @@ -89,7 +89,7 @@ class JanusVisionText2TextModelTester: "use_labels": True, "image_size": 20, "patch_size": 5, - "num_image_tokens": 4, + "num_image_tokens": 16, "num_channels": 3, "is_training": True, "hidden_size": 32, diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index 2b134bf5a0..56fe8dcbb2 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -61,6 +61,7 @@ class LlavaOnevisionVisionText2TextModelTester: parent, ignore_index=-100, image_token_index=1, + video_token_index=2, projector_hidden_act="gelu", seq_length=7, vision_feature_select_strategy="full", @@ -108,6 +109,7 @@ class LlavaOnevisionVisionText2TextModelTester: self.parent = parent self.ignore_index = ignore_index self.image_token_index = image_token_index + self.video_token_index = video_token_index self.projector_hidden_act = projector_hidden_act self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer @@ -134,6 +136,7 @@ class LlavaOnevisionVisionText2TextModelTester: vision_config=self.vision_config, ignore_index=self.ignore_index, image_token_index=self.image_token_index, + video_token_index=self.video_token_index, projector_hidden_act=self.projector_hidden_act, vision_feature_select_strategy=self.vision_feature_select_strategy, vision_feature_layer=self.vision_feature_layer,