[vlm] fix loading of retrieval VLMs (#39242)
* fix vlm with retrieval * we can't use AutoModel because new ColQwen was released after refactor * no need for colqwen * tied weight keys are necessary, if using ImageTextToText * need to apply renaming in tied weights, only for ColPali * overwrite tied keys in ColPali * fix copies, modular can't handle if-statements
This commit is contained in:
committed by
GitHub
parent
b1d14086e4
commit
9f41f67135
@@ -231,6 +231,7 @@ TORCH_INIT_FUNCTIONS = {
|
|||||||
VLMS = [
|
VLMS = [
|
||||||
"aria",
|
"aria",
|
||||||
"ayavision",
|
"ayavision",
|
||||||
|
"colpali",
|
||||||
"emu3",
|
"emu3",
|
||||||
"fuyu",
|
"fuyu",
|
||||||
"gotocr2",
|
"gotocr2",
|
||||||
|
|||||||
@@ -97,15 +97,20 @@ class ColPaliForRetrievalOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
||||||
|
_checkpoint_conversion_mapping = {
|
||||||
|
"vlm.language_model.model": "vlm.model.language_model",
|
||||||
|
"vlm.vision_tower": "vlm.model.vision_tower",
|
||||||
|
"vlm.multi_modal_projector": "vlm.model.multi_modal_projector",
|
||||||
|
"vlm.language_model.lm_head": "vlm.lm_head",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, config: ColPaliConfig):
|
def __init__(self, config: ColPaliConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vocab_size = config.vlm_config.text_config.vocab_size
|
self.vocab_size = config.vlm_config.text_config.vocab_size
|
||||||
|
|
||||||
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
|
self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
|
||||||
if vlm._tied_weights_keys is not None:
|
self._tied_weights_keys = [f"vlm.language_model.{k}" for k in (self.vlm._tied_weights_keys or [])]
|
||||||
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
|
|
||||||
self.vlm = vlm
|
|
||||||
|
|
||||||
self.embedding_dim = self.config.embedding_dim
|
self.embedding_dim = self.config.embedding_dim
|
||||||
self.embedding_proj_layer = nn.Linear(
|
self.embedding_proj_layer = nn.Linear(
|
||||||
@@ -136,7 +141,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
vlm_output = self.vlm(
|
vlm_output = self.vlm.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
@@ -148,7 +153,7 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
|||||||
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
|
vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
|
||||||
vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
|
vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
|
||||||
|
|
||||||
last_hidden_states = vlm_output.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
last_hidden_states = vlm_output[0] # (batch_size, sequence_length, hidden_size)
|
||||||
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
|
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||||
|
|
||||||
# L2 normalization
|
# L2 normalization
|
||||||
@@ -177,12 +182,6 @@ class ColPaliForRetrieval(ColPaliPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.vlm.set_output_embeddings(new_embeddings)
|
self.vlm.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
def set_decoder(self, decoder):
|
|
||||||
self.vlm.set_decoder(decoder)
|
|
||||||
|
|
||||||
def get_decoder(self):
|
|
||||||
return self.vlm.get_decoder()
|
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
return self.vlm.tie_weights()
|
return self.vlm.tie_weights()
|
||||||
|
|
||||||
|
|||||||
@@ -104,21 +104,21 @@ class ColQwen2ForRetrievalOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
||||||
|
_checkpoint_conversion_mapping = {}
|
||||||
|
|
||||||
def __init__(self, config: ColQwen2Config):
|
def __init__(self, config: ColQwen2Config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vocab_size = config.vlm_config.text_config.vocab_size
|
self.vocab_size = config.vlm_config.text_config.vocab_size
|
||||||
|
|
||||||
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
|
self.vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
|
||||||
if vlm._tied_weights_keys is not None:
|
|
||||||
self._tied_weights_keys = [f"vlm.{k}" for k in vlm._tied_weights_keys]
|
|
||||||
self.vlm = vlm
|
|
||||||
|
|
||||||
self.embedding_dim = self.config.embedding_dim
|
self.embedding_dim = self.config.embedding_dim
|
||||||
self.embedding_proj_layer = nn.Linear(
|
self.embedding_proj_layer = nn.Linear(
|
||||||
self.config.vlm_config.text_config.hidden_size,
|
self.config.vlm_config.text_config.hidden_size,
|
||||||
self.embedding_dim,
|
self.embedding_dim,
|
||||||
)
|
)
|
||||||
|
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
|
||||||
|
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
|||||||
|
|
||||||
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
|
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids)
|
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
||||||
@@ -228,12 +228,6 @@ class ColQwen2ForRetrieval(ColQwen2PreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.vlm.set_output_embeddings(new_embeddings)
|
self.vlm.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
def set_decoder(self, decoder):
|
|
||||||
self.vlm.set_decoder(decoder)
|
|
||||||
|
|
||||||
def get_decoder(self):
|
|
||||||
return self.vlm.get_decoder()
|
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
return self.vlm.tie_weights()
|
return self.vlm.tie_weights()
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from ...image_utils import ImageInput, is_valid_image
|
|||||||
from ...processing_utils import ProcessingKwargs, Unpack
|
from ...processing_utils import ProcessingKwargs, Unpack
|
||||||
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
from ...tokenization_utils_base import PreTokenizedInput, TextInput
|
||||||
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
|
from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torch_available, logging
|
||||||
|
from .configuration_colqwen2 import ColQwen2Config
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -272,6 +273,13 @@ class ColQwen2ForRetrievalOutput(ModelOutput):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
class ColQwen2ForRetrieval(ColPaliForRetrieval):
|
class ColQwen2ForRetrieval(ColPaliForRetrieval):
|
||||||
|
_checkpoint_conversion_mapping = {}
|
||||||
|
|
||||||
|
def __init__(self, config: ColQwen2Config):
|
||||||
|
super().__init__(config)
|
||||||
|
del self._tied_weights_keys
|
||||||
|
self._tied_weights_keys = [f"vlm.{k}" for k in (self.vlm._tied_weights_keys or [])]
|
||||||
|
|
||||||
@can_return_tuple
|
@can_return_tuple
|
||||||
@auto_docstring
|
@auto_docstring
|
||||||
def forward(
|
def forward(
|
||||||
@@ -322,7 +330,7 @@ class ColQwen2ForRetrieval(ColPaliForRetrieval):
|
|||||||
|
|
||||||
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
|
# Custom data preparation to fix an issue with the gradient flow when training with multiple GPUs.
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.vlm.model.language_model.embed_tokens(input_ids)
|
inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
|
||||||
|
|
||||||
if pixel_values is not None:
|
if pixel_values is not None:
|
||||||
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
|
||||||
|
|||||||
@@ -13,7 +13,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Testing suite for the PyTorch ColPali model."""
|
"""Testing suite for the PyTorch ColPali model."""
|
||||||
|
|
||||||
|
import collections
|
||||||
import gc
|
import gc
|
||||||
|
import re
|
||||||
import unittest
|
import unittest
|
||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
@@ -40,6 +42,8 @@ from transformers.testing_utils import (
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from transformers.pytorch_utils import id_tensor_storage
|
||||||
|
|
||||||
|
|
||||||
class ColPaliForRetrievalModelTester:
|
class ColPaliForRetrievalModelTester:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -206,6 +210,43 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
|
self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
|
||||||
|
|
||||||
|
# ColPali uses a VLM internally which has its state dict keys renamed with `conversion_mapping`
|
||||||
|
# This test is written assuming that `_tied_weights_keys` are not going to be renamed, thus we
|
||||||
|
# overwrite it. NOTE: ColPali inference/save/load works without issues, it is the testcase
|
||||||
|
# that makes general assumptions
|
||||||
|
def test_tied_weights_keys(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.get_text_config().tie_word_embeddings = True
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model_tied = model_class(config)
|
||||||
|
|
||||||
|
ptrs = collections.defaultdict(list)
|
||||||
|
for name, tensor in model_tied.state_dict().items():
|
||||||
|
ptrs[id_tensor_storage(tensor)].append(name)
|
||||||
|
|
||||||
|
# These are all the pointers of shared tensors.
|
||||||
|
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
|
||||||
|
|
||||||
|
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
|
||||||
|
# Detect we get a hit for each key
|
||||||
|
for key in tied_weight_keys:
|
||||||
|
key = key.replace(".language_model", "") # remove 'language_model' prefix
|
||||||
|
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
|
||||||
|
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
|
||||||
|
|
||||||
|
# Removed tied weights found from tied params -> there should only be one left after
|
||||||
|
for key in tied_weight_keys:
|
||||||
|
key = key.replace(".language_model", "") # remove 'language_model' prefix
|
||||||
|
for i in range(len(tied_params)):
|
||||||
|
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
|
||||||
|
|
||||||
|
tied_params = [group for group in tied_params if len(group) > 1]
|
||||||
|
self.assertListEqual(
|
||||||
|
tied_params,
|
||||||
|
[],
|
||||||
|
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
|
||||||
|
)
|
||||||
|
|
||||||
@unittest.skip(
|
@unittest.skip(
|
||||||
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user