From 3dd1de39bb37887d2c883ecf82ddca6d17924875 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 6 Feb 2025 14:31:32 +0100 Subject: [PATCH] Paligemma: fix generation with Gemma2 (#36044) * fix paligemma * nit * use `kwargs` in models that can load any LM --- .../models/llava/modeling_llava.py | 2 + .../models/llava_next/modeling_llava_next.py | 2 + .../modeling_llava_next_video.py | 2 + .../modular_llava_next_video.py | 2 + .../modeling_llava_onevision.py | 2 + .../models/paligemma/modeling_paligemma.py | 14 +- .../video_llava/modeling_video_llava.py | 2 + .../models/vipllava/modeling_vipllava.py | 2 + tests/models/paligemma2/__init__.py | 0 .../paligemma2/test_modeling_paligemma2.py | 350 ++++++++++++++++++ 10 files changed, 372 insertions(+), 6 deletions(-) create mode 100644 tests/models/paligemma2/__init__.py create mode 100644 tests/models/paligemma2/test_modeling_paligemma2.py diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 50b666758f..36f212e768 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -425,6 +425,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, image_sizes: torch.Tensor = None, + **lm_kwargs, ) -> Union[Tuple, LlavaCausalLMOutputWithPast]: r""" Args: @@ -520,6 +521,7 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index b9387eaab0..06e1cc6394 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -794,6 +794,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]: r""" Args: @@ -896,6 +897,7 @@ class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixi return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3e288520ed..f62824947d 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -829,6 +829,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -991,6 +992,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, Gene return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 77e0f08c7e..b2e06c337c 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -360,6 +360,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]: r""" Args: @@ -522,6 +523,7 @@ class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index b75ef9ab0b..ed584bda7f 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -619,6 +619,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, LlavaOnevisionCausalLMOutputWithPast]: r""" Args: @@ -766,6 +767,7 @@ class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, Gene return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 5889f92f3c..b6dab1830c 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -342,8 +342,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi token_type_ids, past_key_values, cache_position, - input_ids=None, - inputs_embeds=None, + input_tensor, is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": @@ -353,8 +352,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi using_static_cache = isinstance(past_key_values, StaticCache) min_dtype = torch.finfo(self.dtype).min - inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] - sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + inputs_lead_dim, sequence_length = input_tensor.shape[:2] if using_static_cache: target_length = past_key_values.get_max_cache_shape() elif isinstance(past_key_values, HybridCache): @@ -432,6 +430,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]: r""" Args: @@ -524,7 +523,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels) causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training + attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training ) outputs = self.language_model( attention_mask=causal_mask, @@ -537,6 +536,7 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs.logits @@ -612,10 +612,12 @@ class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixi model_inputs["pixel_values"] = pixel_values is_training = token_type_ids is not None and labels is not None if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training + attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training ) model_inputs["attention_mask"] = causal_mask + return model_inputs diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index c7dd0a1f93..d8da974b98 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -463,6 +463,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" Args: @@ -616,6 +617,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMi return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 0fbef33086..71201db209 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -400,6 +400,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, ) -> Union[Tuple, VipLlavaCausalLMOutputWithPast]: r""" Args: @@ -490,6 +491,7 @@ class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin) return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, + **lm_kwargs, ) logits = outputs[0] diff --git a/tests/models/paligemma2/__init__.py b/tests/models/paligemma2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/paligemma2/test_modeling_paligemma2.py b/tests/models/paligemma2/test_modeling_paligemma2.py new file mode 100644 index 0000000000..4a87eb329d --- /dev/null +++ b/tests/models/paligemma2/test_modeling_paligemma2.py @@ -0,0 +1,350 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch PaliGemma model.""" + +import unittest + +from transformers import ( + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + require_torch, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + +if is_vision_available(): + pass + + +class PaliGemma2VisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_index=0, + projector_hidden_act="gelu", + seq_length=25, + vision_feature_select_strategy="default", + vision_feature_layer=-1, + projection_dim=32, + text_config={ + "model_type": "gemma2", + "seq_length": 128, + "is_training": True, + # "use_input_mask": True, + "use_token_type_ids": False, + "use_labels": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "head_dim": 8, + "intermediate_size": 37, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 16, + "type_sequence_label_size": 2, + "initializer_range": 0.02, + "num_labels": 3, + "num_choices": 4, + "pad_token_id": 1, + }, + is_training=True, + vision_config={ + "use_labels": True, + "image_size": 20, + "patch_size": 5, + "num_image_tokens": 4, + "num_channels": 3, + "is_training": True, + "hidden_size": 32, + "projection_dim": 32, + "num_key_value_heads": 1, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "intermediate_size": 37, + "dropout": 0.1, + "attention_dropout": 0.1, + "initializer_range": 0.02, + }, + use_cache=False, + ): + self.parent = parent + self.ignore_index = ignore_index + # `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.seq_length = seq_length + self.projection_dim = projection_dim + self.pad_token_id = text_config["pad_token_id"] + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_channels = vision_config["num_channels"] + self.image_size = vision_config["image_size"] + self.encoder_seq_length = seq_length + self.use_cache = use_cache + + def get_config(self): + return PaliGemmaConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_index=self.image_token_index, + projector_hidden_act=self.projector_hidden_act, + projection_dim=self.projection_dim, + vision_feature_select_strategy=self.vision_feature_select_strategy, + vision_feature_layer=self.vision_feature_layer, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + attention_mask = input_ids.ne(self.pad_token_id).to(torch_device) + + # set the 16 first tokens to be image, and ensure that no other tokens are image tokens + # do not change this unless you modified image size or patch size + input_ids[input_ids == config.image_token_index] = self.pad_token_id + input_ids[:, :16] = config.image_token_index + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": input_ids, + "token_type_ids": torch.zeros_like(input_ids), + } + return config, inputs_dict + + +@require_torch +class PaliGemma2ForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `PaliGemmaForConditionalGeneration`. + """ + + all_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (PaliGemmaForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"image-text-to-text": PaliGemmaForConditionalGeneration} + fx_compatible = False + test_pruning = False + test_torchscript = False + test_head_masking = False + _is_composite = True + + def setUp(self): + self.model_tester = PaliGemma2VisionText2TextModelTester(self) + self.config_tester = ConfigTester(self, config_class=PaliGemmaConfig, has_text_modality=False) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + wte = model.get_input_embeddings() + inputs["inputs_embeds"] = wte(input_ids) + + with torch.no_grad(): + model(**inputs) + + # overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs + # while some other models require pixel_values to be present + def test_inputs_embeds_matches_input_ids(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + input_ids = inputs["input_ids"] + del inputs["input_ids"] + del inputs["pixel_values"] + + inputs_embeds = model.get_input_embeddings()(input_ids) + + with torch.no_grad(): + out_ids = model(input_ids=input_ids, **inputs)[0] + out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0] + torch.testing.assert_close(out_embeds, out_ids) + + # Copied from tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest.test_mismatching_num_image_tokens + def test_mismatching_num_image_tokens(self): + """ + Tests that VLMs through an error with explicit message saying what is wrong + when number of images don't match number of image tokens in the text. + Also we need to test multi-image cases when one prompr has multiple image tokens. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + _ = model(**input_dict) # successfull forward with no modifications + + # remove one image but leave the image token in text + input_dict["pixel_values"] = input_dict["pixel_values"][-1:, ...] + with self.assertRaises(ValueError): + _ = model(**input_dict) + + # simulate multi-image case by concatenating inputs where each has exactly one image/image-token + input_ids = input_dict["input_ids"][:1] + pixel_values = input_dict["pixel_values"][:1] + input_ids = torch.cat([input_ids, input_ids], dim=0) + + # one image and two image tokens raise an error + with self.assertRaises(ValueError): + _ = model(input_ids=input_ids, pixel_values=pixel_values) + + # two images and two image tokens don't raise an error + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + _ = model(input_ids=input_ids, pixel_values=pixel_values) + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_cpu_offload(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_bin(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_disk_offload_safetensors(self): + pass + + @unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.") + def test_model_parallelism(self): + pass + + @unittest.skip( + reason="PaliGemmma's SigLip encoder uses the same initialization scheme as the Flax original implementation" + ) + def test_initialization(self): + pass + + # TODO extend valid outputs to include this test @Molbap + @unittest.skip(reason="PaliGemma has currently one output format.") + def test_model_outputs_equivalence(self): + pass + + # TODO fix the loss = nan in the testing configuration chosen @Molbap + @unittest.skip(reason="Edge case giving loss nan values in testing configuration.") + def test_determinism(self): + pass + + @unittest.skip(reason="PaliGemma does not use feedforward chunking.") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="PaliGemma does not support low_cpu_mem_usage.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip(reason="PaliGemma does not support low_cpu_mem_usage.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip(reason="PaliGemma does not support low_cpu_mem_usage.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + + @unittest.skip( + reason="VLMs doen't accept inputs embeds and pixel values at the same time. So if the test passed for bacbone LM, it passes for VLM also" + ) + def test_generate_from_inputs_embeds_with_static_cache(self): + pass + + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + # TODO (joao, raushan): fix me -- the problem is in `cache_position[0] == 0`, i.e. dynamic control flow + @unittest.skip("PaliGemma is not compatible with end-to-end generation compilation") + def test_generate_compile_model_forward(self): + pass + + @unittest.skip("Low memory will be removed soon so no need to fix it") + def test_beam_search_low_memory(self): + pass