[vlm] fix loading of retrieval VLMs (#39242)

* fix vlm with retrieval

* we can't use AutoModel because new ColQwen was released after refactor

* no need for colqwen

* tied weight keys are necessary, if using IMageTextToText

* need to apply renaming in tied weights, only for ColPali

* overwrite tied keys in ColPali

* fix copies, modular can't handle if-statements
This commit is contained in:
Raushan Turganbay
2025-07-15 20:23:54 +05:00
committed by GitHub
parent b1d14086e4
commit 9f41f67135
5 changed files with 67 additions and 24 deletions

View File

@@ -231,6 +231,7 @@ TORCH_INIT_FUNCTIONS = {
VLMS = [ VLMS = [
"aria", "aria",
"ayavision", "ayavision",
"colpali",
"emu3", "emu3",
"fuyu", "fuyu",
"gotocr2", "gotocr2",

View File

@@ -97,15 +97,20 @@ class ColPaliForRetrievalOutput(ModelOutput):
""" """
) )
class ColPaliForRetrieval(ColPaliPreTrainedModel): class ColPaliForRetrieval(ColPaliPreTrainedModel):
_checkpoint_conversion_mapping = {
"vlm.language_model.model": "vlm.model.language_model",
"vlm.vision_tower": "vlm.model.vision_tower",
"vlm.multi_modal_projector": "vlm.model.multi_modal_projector",
"vlm.language_model.lm_head": "vlm.lm_head",
}
def __init__(self, config: ColPaliConfig): def __init__(self, config: ColPaliConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.vocab_size = config.vlm_config.text_config.vocab_size self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config) self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
if vlm._tied_weights_keys is not None: self._tied_weights_keys = [f"vlm.language_model.{k}" for k in (self.vlm._tied_weights_keys or [])]
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim self.embedding_dim = self.config.embedding_dim
self.embedding_proj_layer = nn.Linear( self.embedding_proj_layer = nn.Linear(
@@ -136,7 +141,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
) )
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vlm_output = self.vlm( vlm_output = self.vlm.model(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
pixel_values=pixel_values, pixel_values=pixel_values,
@@ -148,7 +153,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
last_hidden_states = vlm_output.hidden_states[-1] # (batch_size, sequence_length, hidden_size) last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim) embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
# L2 normalization # L2 normalization
@@ -177,12 +182,6 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.vlm.set_output_embeddings(new_embeddings) self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.vlm.set_decoder(decoder)
def get_decoder(self):
return self.vlm.get_decoder()
def tie_weights(self): def tie_weights(self):
return self.vlm.tie_weights() return self.vlm.tie_weights()

View File

@@ -104,21 +104,21 @@ class ColQwen2ForRetrievalOutput(ModelOutput):
""" """
) )
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel): class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
_checkpoint_conversion_mapping = {}
def __init__(self, config: ColQwen2Config): def __init__(self, config: ColQwen2Config):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
self.vocab_size = config.vlm_config.text_config.vocab_size self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config) self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
if vlm._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim self.embedding_dim = self.config.embedding_dim
self.embedding_proj_layer = nn.Linear( self.embedding_proj_layer = nn.Linear(
self.config.vlm_config.text_config.hidden_size, self.config.vlm_config.text_config.hidden_size,
self.embedding_dim, self.embedding_dim,
) )
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
self.post_init() self.post_init()
@@ -172,7 +172,7 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids) inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
if pixel_values is not None: if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
@@ -228,12 +228,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.vlm.set_output_embeddings(new_embeddings) self.vlm.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.vlm.set_decoder(decoder)
def get_decoder(self):
return self.vlm.get_decoder()
def tie_weights(self): def tie_weights(self):
return self.vlm.tie_weights() return self.vlm.tie_weights()

View File

@@ -25,6 +25,7 @@ from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessingKwargs, Unpack from ...processing_utils import ProcessingKwargs, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
from .configuration_colqwen2 import ColQwen2Config
if is_torch_available(): if is_torch_available():
@@ -272,6 +273,13 @@ class ColQwen2ForRetrievalOutput(ModelOutput):
""" """
) )
class ColQwen2ForRetrieval(ColPaliForRetrieval): class ColQwen2ForRetrieval(ColPaliForRetrieval):
_checkpoint_conversion_mapping = {}
def __init__(self, config: ColQwen2Config):
super().__init__(config)
del self._tied_weights_keys
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
@can_return_tuple @can_return_tuple
@auto_docstring @auto_docstring
def forward( def forward(
@@ -322,7 +330,7 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval):
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs. # Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids) inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
if pixel_values is not None: if pixel_values is not None:
pixel_values = pixel_values.type(self.vlm.visual.get_dtype()) pixel_values = pixel_values.type(self.vlm.visual.get_dtype())

View File

@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
"""Testing suite for the PyTorch ColPali model.""" """Testing suite for the PyTorch ColPali model."""
import collections
import gc import gc
import re
import unittest import unittest
from typing import ClassVar from typing import ClassVar
@@ -40,6 +42,8 @@ from transformers.testing_utils import (
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers.pytorch_utils import id_tensor_storage
class ColPaliForRetrievalModelTester: class ColPaliForRetrievalModelTester:
def __init__( def __init__(
@@ -206,6 +210,43 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsInstance(outputs, ColPaliForRetrievalOutput) self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
# ColPali uses a VLM internally which has its state dict keys renames with `conversion_mapping`
# This test is written assuming that `_tied_weights_keys` are not going to be renamed, thus we
# overwrite it. NOTE: ColPali inference/save/load works without issues, it is the testcase
# that makes general assumptions
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.get_text_config().tie_word_embeddings = True
for model_class in self.all_model_classes:
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
key = key.replace(".language_model", "") # remove 'language_model' prefix
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
key = key.replace(".language_model", "") # remove 'language_model' prefix
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
@unittest.skip( @unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
) )