[vlm] fix loading of retrieval VLMs (#39242)
* fix vlm with retrieval * we can't use AutoModel because new ColQwen was released after refactor * no need for colqwen * tied weight keys are necessary, if using ImageTextToText * need to apply renaming in tied weights, only for ColPali * overwrite tied keys in ColPali * fix copies, modular can't handle if-statements
This commit is contained in:
committed by
GitHub
parent
b1d14086e4
commit
9f41f67135
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch ColPali model."""
|
||||
|
||||
import collections
|
||||
import gc
|
||||
import re
|
||||
import unittest
|
||||
from typing import ClassVar
|
||||
|
||||
@@ -40,6 +42,8 @@ from transformers.testing_utils import (
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.pytorch_utils import id_tensor_storage
|
||||
|
||||
|
||||
class ColPaliForRetrievalModelTester:
|
||||
def __init__(
|
||||
@@ -206,6 +210,43 @@ class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
|
||||
|
||||
# ColPali uses a VLM internally which has its state dict keys renamed with `conversion_mapping`
|
||||
# This test is written assuming that `_tied_weights_keys` are not going to be renamed, thus we
|
||||
# overwrite it. NOTE: ColPali inference/save/load works without issues, it is the testcase
|
||||
# that makes general assumptions
|
||||
def test_tied_weights_keys(self):
    """Check that `_tied_weights_keys` accounts for every shared tensor in the model.

    For each model class, a model is built with `tie_word_embeddings` enabled on the
    text config. State-dict names are grouped by the storage they point to, and two
    things are asserted:

    * every declared tied key (with the `.language_model` segment stripped) matches
      at least one name belonging to a group of shared tensors;
    * once all names matching a declared key are dropped, no group still contains
      more than one name.
    """
    config, _ = self.model_tester.prepare_config_and_inputs_for_common()
    config.get_text_config().tie_word_embeddings = True

    for model_class in self.all_model_classes:
        model_tied = model_class(config)

        # Bucket state-dict names by the storage their tensor lives in.
        names_by_storage = collections.defaultdict(list)
        for param_name, param in model_tied.state_dict().items():
            names_by_storage[id_tensor_storage(param)].append(param_name)

        # Only storages referenced under several names are actually shared.
        shared_groups = [names for names in names_by_storage.values() if len(names) > 1]

        declared_keys = model_tied._tied_weights_keys
        if declared_keys is None:
            declared_keys = []
        # The declared keys carry a 'language_model' prefix that the state-dict
        # names are matched without, so it is removed before matching.
        stripped_keys = [declared.replace(".language_model", "") for declared in declared_keys]

        # Every declared key has to hit at least one shared name.
        for key in stripped_keys:
            is_tied_key = any(re.search(key, shared_name) for names in shared_groups for shared_name in names)
            self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")

        # Drop every name matched by some declared key; at most one name per
        # group may survive, otherwise a tied key declaration is missing.
        unmatched_groups = [
            [
                shared_name
                for shared_name in names
                if all(re.search(key, shared_name) is None for key in stripped_keys)
            ]
            for names in shared_groups
        ]
        tied_params = [names for names in unmatched_groups if len(names) > 1]

        self.assertListEqual(
            tied_params,
            [],
            f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
        )
|
||||
|
||||
@unittest.skip(
|
||||
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user