From 9f41f67135b0656c428ff2c2b446d8eb15f5a7c5 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 15 Jul 2025 20:23:54 +0500 Subject: [PATCH] [vlm] fix loading of retrieval VLMs (#39242) * fix vlm with retrieval * we can't use AutoModel because new ColQwen was released after refactor * no need for colqwen * tied weight keys are necessary, if using IMageTextToText * need to apply renaming in tied weights, only for ColPali * overwrite tied keys in ColPali * fix copies, modular can't handle if-statements --- src/transformers/modeling_utils.py | 1 + .../models/colpali/modeling_colpali.py | 23 +++++------ .../models/colqwen2/modeling_colqwen2.py | 16 +++----- .../models/colqwen2/modular_colqwen2.py | 10 ++++- tests/models/colpali/test_modeling_colpali.py | 41 +++++++++++++++++++ 5 files changed, 67 insertions(+), 24 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 8e10c8eef5..3202ef47c1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -231,6 +231,7 @@ TORCH_INIT_FUNCTIONS = { VLMS = [ "aria", "ayavision", + "colpali", "emu3", "fuyu", "gotocr2", diff --git a/src/transformers/models/colpali/modeling_colpali.py b/src/transformers/models/colpali/modeling_colpali.py index 0ac554ad2e..1f38ea407a 100644 --- a/src/transformers/models/colpali/modeling_colpali.py +++ b/src/transformers/models/colpali/modeling_colpali.py @@ -97,15 +97,20 @@ class ColPaliForRetrievalOutput(ModelOutput): """ ) class ColPaliForRetrieval(ColPaliPreTrainedModel): + _checkpoint_conversion_mapping = { + "vlm.language_model.model": "vlm.model.language_model", + "vlm.vision_tower": "vlm.model.vision_tower", + "vlm.multi_modal_projector": "vlm.model.multi_modal_projector", + "vlm.language_model.lm_head": "vlm.lm_head", + } + def __init__(self, config: ColPaliConfig): super().__init__(config) self.config = config self.vocab_size = config.vlm_config.text_config.vocab_size - vlm = AutoModelForImageTextToText.from_config(config.vlm_config) - if vlm._tied_weights_keys is not None: - self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys] - self.vlm = vlm + self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config) + self._tied_weights_keys = [f"vlm.language_model.{k}" for k in (self.vlm._tied_weights_keys or [])] self.embedding_dim = self.config.embedding_dim self.embedding_proj_layer = nn.Linear( @@ -136,7 +141,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - vlm_output = self.vlm( + vlm_output = self.vlm.model( input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, @@ -148,7 +153,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None - last_hidden_states = vlm_output.hidden_states[-1] # (batch_size, sequence_length, hidden_size) + last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size) embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim) # L2 normalization @@ -177,12 +182,6 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.vlm.set_output_embeddings(new_embeddings) - def set_decoder(self, decoder): - self.vlm.set_decoder(decoder) - - def get_decoder(self): - return self.vlm.get_decoder() - def tie_weights(self): return self.vlm.tie_weights() diff --git a/src/transformers/models/colqwen2/modeling_colqwen2.py b/src/transformers/models/colqwen2/modeling_colqwen2.py index 0a64bcc60e..b0703f665e 100644 --- a/src/transformers/models/colqwen2/modeling_colqwen2.py +++ b/src/transformers/models/colqwen2/modeling_colqwen2.py @@ -104,21 +104,21 @@ class ColQwen2ForRetrievalOutput(ModelOutput): """ ) class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): + _checkpoint_conversion_mapping = {} + def __init__(self, config: ColQwen2Config): super().__init__(config) self.config = config self.vocab_size = config.vlm_config.text_config.vocab_size - vlm = AutoModelForImageTextToText.from_config(config.vlm_config) - if vlm._tied_weights_keys is not None: - self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys] - self.vlm = vlm + self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config) self.embedding_dim = self.config.embedding_dim self.embedding_proj_layer = nn.Linear( self.config.vlm_config.text_config.hidden_size, self.embedding_dim, ) + self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])] self.post_init() @@ -172,7 +172,7 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. if inputs_embeds is None: - inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids) + inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) @@ -228,12 +228,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): def set_output_embeddings(self, new_embeddings): self.vlm.set_output_embeddings(new_embeddings) - def set_decoder(self, decoder): - self.vlm.set_decoder(decoder) - - def get_decoder(self): - return self.vlm.get_decoder() - def tie_weights(self): return self.vlm.tie_weights() diff --git a/src/transformers/models/colqwen2/modular_colqwen2.py b/src/transformers/models/colqwen2/modular_colqwen2.py index 4f5ce4aa8a..f63e865a71 100644 --- a/src/transformers/models/colqwen2/modular_colqwen2.py +++ b/src/transformers/models/colqwen2/modular_colqwen2.py @@ -25,6 +25,7 @@ from ...image_utils import ImageInput, is_valid_image from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging +from .configuration_colqwen2 import ColQwen2Config if is_torch_available(): @@ -272,6 +273,13 @@ class ColQwen2ForRetrievalOutput(ModelOutput): """ ) class ColQwen2ForRetrieval(ColPaliForRetrieval): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: ColQwen2Config): + super().__init__(config) + del self._tied_weights_keys + self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])] + @can_return_tuple @auto_docstring def forward( @@ -322,7 +330,7 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval): # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. if inputs_embeds is None: - inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids) + inputs_embeds = self.vlm.language_model.embed_tokens(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) diff --git a/tests/models/colpali/test_modeling_colpali.py b/tests/models/colpali/test_modeling_colpali.py index 3cdfc06227..c6c9892d17 100644 --- a/tests/models/colpali/test_modeling_colpali.py +++ b/tests/models/colpali/test_modeling_colpali.py @@ -13,7 +13,9 @@ # limitations under the License. """Testing suite for the PyTorch ColPali model.""" +import collections import gc +import re import unittest from typing import ClassVar @@ -40,6 +42,8 @@ from transformers.testing_utils import ( if is_torch_available(): import torch + from transformers.pytorch_utils import id_tensor_storage + class ColPaliForRetrievalModelTester: def __init__( @@ -206,6 +210,43 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsInstance(outputs, ColPaliForRetrievalOutput) + # ColPali uses a VLM internally which has its state dict keys renames with `conversion_mapping` + # This test is written assuming that `_tied_weights_keys` are not going to be renamed, thus we + # overwrite it. NOTE: ColPali inference/save/load works without issues, it is the testcase + # that makes general assumptions + def test_tied_weights_keys(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + config.get_text_config().tie_word_embeddings = True + for model_class in self.all_model_classes: + model_tied = model_class(config) + + ptrs = collections.defaultdict(list) + for name, tensor in model_tied.state_dict().items(): + ptrs[id_tensor_storage(tensor)].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + + tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else [] + # Detect we get a hit for each key + for key in tied_weight_keys: + key = key.replace(".language_model", "") # remove 'language_model' prefix + is_tied_key = any(re.search(key, p) for group in tied_params for p in group) + self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.") + + # Removed tied weights found from tied params -> there should only be one left after + for key in tied_weight_keys: + key = key.replace(".language_model", "") # remove 'language_model' prefix + for i in range(len(tied_params)): + tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None] + + tied_params = [group for group in tied_params if len(group) > 1] + self.assertListEqual( + tied_params, + [], + f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.", + ) + @unittest.skip( reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" )