Add ColPali to 🤗 transformers (#33736)

* feat: run `add-new-model-like`

* feat: add paligemma code with "copied from"

* feat: add ColPaliProcessor

* feat: add ColPaliModel

* feat: add ColPaliConfig

* feat: rename `ColPaliForConditionalGeneration` to `ColPaliModel`

* fixup modeling colpali

* fix: fix root import shortcuts

* fix: fix `modeling_auto` dict

* feat: comment out ColPali test file

* fix: fix typos from `add-new-model-like`

* feat: explicit the forward input args

* feat: move everything to `modular_colpali.py`

* fix: put back ColPaliProcesor

* feat: add auto-generated files

* fix: run `fix-copies`

* fix: remove DOCStRING constants to make modular converter work

* fix: fix typo + modular converter

* fix: add missing imports

* feat: no more errors when loading ColPaliModel

* fix: remove unused args in forward + tweak doc

* feat: rename `ColPaliModel` to `ColPaliForRetrieval`

* fix: apply `fix-copies`

* feat: add ColPaliProcessor to `modular_colpali`

* fix: run make quality + make style

* fix: remove duplicate line in configuration_auto

* feat: make ColPaliModel inehrit from PaliGemmaForConditionalGeneration

* fix: tweak and use ColPaliConfig

* feat: rename `score` to `post_process_retrieval`

* build: run modular formatter + make style

* feat: convert colpali weights + fixes

* feat: remove old weight converter file

* feat: add and validate tests

* feat: replace harcoded path to "vidore/colpali-v1.2-hf" in tests

* fix: add bfloat16 conversion in weight converter

* feat: replace pytest with unittest in modeling colpali test

* feat: add sanity check for weight conversion (doesn't work yet)

* feat: add shape sanity check in weigth converter

* feat: make ColPaliProcessor args explicit

* doc: add doc for ColPali

* fix: trying to fix output mismatch

* feat: tweaks

* fix: ColPaliModelOutput inherits from ModelOutput instead of PaliGemmaCausalLMOutputWithPast

* fix: address comments on PR

* fix: adapt tests to the Hf norm

* wip: try things

* feat: add `__call__` method to `ColPaliProcessor`

* feat: remove need for dummy image in `process_queries`

* build: run new modular converter

* fix: fix incorrect method override

* Fix tests, processing, modular, convert

* fix tokenization auto

* hotfix: manually fix processor -> fixme once convert modular is fixed

* fix: convert weights working

* feat: rename and improve convert weight script

* feat: tweaks

* fest: remove `device` input for `post_process_retrieval`

* refactor: remove unused `get_torch_device`

* Fix all tests

* docs: update ColPali model doc

* wip: fix convert weights to hf

* fix logging modular

* docs: add acknowledgements in model doc

* docs: add missing docstring to ColPaliProcessor

* docs: tweak

* docs: add doc for `ColPaliForRetrievalOutput.forward`

* feat: add modifications from colpali-engine v0.3.2 in ColPaliProcessor

* fix: fix and upload colapli hf weights

* refactor: rename `post_process_retrieval` to `score_retrieval`

* fix: fix wrong typing for `score_retrieval`

* test: add integration test for ColPali

* chore: rerun convert modular

* build: fix root imports

* Update docs/source/en/index.md

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>

* fix: address PR comments

* wip: reduce the prediction gap in weight conversion

* docs: add comment in weight conversion script

* docs: add example for `ColPaliForRetrieval.forward`

* tests: change dataset path to the new one in hf-internal

* fix: colpali weight conversion works

* test: add fine-grained check for ColPali integration test

* fix: fix typos in convert weight script

* docs: move input docstring in a variable

* fix: remove hardcoded torch device in test

* fix: run the new modular refactor

* docs: fix python example for ColPali

* feat: add option to choose `score_retrieval`'s output dtype and device

* docs: update doc for `score_retrieval`

* feat: add `patch_size` property in ColPali model

* chore: run `make fix-copies`

* docs: update description for ColPali cookbooks

* fix: remove `ignore_index` methods

* feat: remove non-transformers specific methods

* feat: update `__init__.py` to new hf format

* fix: fix root imports in transformers

* feat: remove ColPali's inheritance from PaliGemma

* Fix CI issues

* nit remove prints

* feat: remove ColPali config and model from `modular_colpali.py`

* feat: add `ColPaliPreTrainedModel` and update modeling and configuration code

* fix: fix auto-removed imports in root `__init__.py`

* fix: various fixes

* fix: fix `_init_weight`

* temp: comment `AutoModel.from_config` for experiments

* fix: add missing `output_attentions` arg in ColPali's forward

* fix: fix `resize_token_embeddings`

* fix: make `input_ids` optional in forward

* feat: rename `projection_layer` to `embedding_proj_layer`

* wip: fix convert colpali weight script

* fix tests and convert weights from original repo

* fix unprotected import

* fix unprotected torch import

* fix style

* change vlm_backbone_config to vlm_config

* fix unprotected import in modular this time

* fix: load config from Hub + tweaks in convert weight script

* docs: move example usage from model docstring to model markdown

* docs: fix input docstring for ColPali's forward method

* fix: use `sub_configs` for ColPaliConfig

* fix: remove non-needed sanity checks in weight conversion script + tweaks

* fix: fix issue with `replace_return_docstrings` in ColPali's `forward`

* docs: update docstring for `ColPaliConfig`

* test: change model path in ColPali test

* fix: fix ColPaliConfig

* fix: fix weight conversion script

* test: fix expected weights for ColPali model

* docs: update ColPali markdown

* docs: fix minor typo in ColPaliProcessor

* Fix tests and add _no_split_modules

* add text_config to colpali config

* [run slow] colpali

* move inputs to torch_device in integration test

* skip test_model_parallelism

* docs: clarify quickstart snippet in ColPali's model card

* docs: update ColPali's model card

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Tony Wu
2024-12-17 11:26:43 +01:00
committed by GitHub
parent a7f5479b45
commit f33a0cebb3
22 changed files with 2204 additions and 2 deletions

View File

@@ -834,6 +834,8 @@
title: CLIPSeg
- local: model_doc/clvp
title: CLVP
- local: model_doc/colpali
title: ColPali
- local: model_doc/data2vec
title: Data2Vec
- local: model_doc/deplot

View File

@@ -100,6 +100,7 @@ Flax), PyTorch, and/or TensorFlow.
| [CodeLlama](model_doc/code_llama) | ✅ | ❌ | ✅ |
| [Cohere](model_doc/cohere) | ✅ | ❌ | ❌ |
| [Cohere2](model_doc/cohere2) | ✅ | ❌ | ❌ |
| [ColPali](model_doc/colpali) | ✅ | ❌ | ❌ |
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |

View File

@@ -0,0 +1,95 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# ColPali
## Overview
The ColPali model was proposed in [ColPali: Efficient Document Retrieval with Vision Language Models](https://doi.org/10.48550/arXiv.2407.01449) by **Manuel Faysse***, **Hugues Sibille***, **Tony Wu***, Bilel Omrani, Gautier Viaud, Céline Hudelot, Pierre Colombo (* denotes equal contribution).
With our new model *ColPali*, we propose to leverage VLMs to construct efficient multi-vector embeddings in the visual space for document retrieval. By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. We train the model to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method.
Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account both the textual and visual content (layout, charts, ...) of a document. ColPali is also highly interpretable: similarity maps can be obtained between patches and query tokens. These maps highlight ColPalis strong OCR capabilities and chart understanding.
**Paper abstract:**
> Documents are visually rich structures that convey information through text, but also figures, page layouts, tables, or even fonts. Since modern retrieval systems mainly rely on the textual information they extract from document pages to index documents -often through lengthy and brittle processes-, they struggle to exploit key visual cues efficiently. This limits their capabilities in many practical document retrieval applications such as Retrieval Augmented Generation (RAG). To benchmark current systems on visually rich document retrieval, we introduce the Visual Document Retrieval Benchmark *ViDoRe*, composed of various page-level retrieval tasks spanning multiple domains, languages, and practical settings. The inherent complexity and performance shortcomings of modern systems motivate a new concept; doing document retrieval by directly embedding the images of the document pages. We release *ColPali*, a Vision Language Model trained to produce high-quality multi-vector embeddings from images of document pages. Combined with a late interaction matching mechanism, *ColPali* largely outperforms modern document retrieval pipelines while being drastically simpler, faster and end-to-end trainable.
>
> We release models, data, code and benchmarks under open licenses at [https://huggingface.co/vidore](https://huggingface.co/vidore).
## Resources
- The official blog post detailing ColPali can be found [here](https://huggingface.co/blog/manu/colpali). 📝
- The original model implementation code for the ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎
- Cookbooks for learning to use the transformers-native version of ColPali, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚
This model was contributed by [@tonywu71](https://huggingface.co/tonywu71) and [@yonigozlan](https://huggingface.co/yonigozlan).
## Usage
This example demonstrates how to use ColPali to embed both queries and images, calculate their similarity scores, and identify the most relevant matches. For a specific query, you can retrieve the top-k most similar images by selecting the ones with the highest similarity scores.
```python
import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
model_name = "vidore/colpali-v1.2-hf"
model = ColPaliForRetrieval.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="cuda:0", # or "mps" if on Apple Silicon
).eval()
processor = ColPaliProcessor.from_pretrained(model_name)
# Your inputs (replace dummy images with screenshots of your documents)
images = [
Image.new("RGB", (32, 32), color="white"),
Image.new("RGB", (16, 16), color="black"),
]
queries = [
"What is the organizational structure for our R&D department?",
"Can you provide a breakdown of last years financial performance?",
]
# Process the inputs
batch_images = processor(images=images).to(model.device)
batch_queries = processor(text=queries).to(model.device)
# Forward pass
with torch.no_grad():
image_embeddings = model(**batch_images)
query_embeddings = model(**batch_queries)
# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)
```
## ColPaliConfig
[[autodoc]] ColPaliConfig
## ColPaliProcessor
[[autodoc]] ColPaliProcessor
## ColPaliForRetrieval
[[autodoc]] ColPaliForRetrieval
- forward

View File

@@ -306,6 +306,10 @@ _import_structure = {
],
"models.cohere": ["CohereConfig"],
"models.cohere2": ["Cohere2Config"],
"models.colpali": [
"ColPaliConfig",
"ColPaliProcessor",
],
"models.conditional_detr": ["ConditionalDetrConfig"],
"models.convbert": [
"ConvBertConfig",
@@ -1468,6 +1472,7 @@ else:
"MODEL_FOR_OBJECT_DETECTION_MAPPING",
"MODEL_FOR_PRETRAINING_MAPPING",
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_RETRIEVAL_MAPPING",
"MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
@@ -1789,6 +1794,12 @@ else:
)
_import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"])
_import_structure["models.cohere2"].extend(["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"])
_import_structure["models.colpali"].extend(
[
"ColPaliForRetrieval",
"ColPaliPreTrainedModel",
]
)
_import_structure["models.conditional_detr"].extend(
[
"ConditionalDetrForObjectDetection",
@@ -5207,6 +5218,10 @@ if TYPE_CHECKING:
)
from .models.cohere import CohereConfig
from .models.cohere2 import Cohere2Config
from .models.colpali import (
ColPaliConfig,
ColPaliProcessor,
)
from .models.conditional_detr import (
ConditionalDetrConfig,
)
@@ -6413,6 +6428,7 @@ if TYPE_CHECKING:
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_RETRIEVAL_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -6689,6 +6705,10 @@ if TYPE_CHECKING:
Cohere2Model,
Cohere2PreTrainedModel,
)
from .models.colpali import (
ColPaliForRetrieval,
ColPaliPreTrainedModel,
)
from .models.conditional_detr import (
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,

View File

@@ -53,6 +53,7 @@ from . import (
codegen,
cohere,
cohere2,
colpali,
conditional_detr,
convbert,
convnext,

View File

@@ -74,6 +74,7 @@ else:
"MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING",
"MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING",
"MODEL_FOR_VISION_2_SEQ_MAPPING",
"MODEL_FOR_RETRIEVAL_MAPPING",
"MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING",
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING",
"MODEL_MAPPING",
@@ -252,6 +253,7 @@ if TYPE_CHECKING:
MODEL_FOR_OBJECT_DETECTION_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_RETRIEVAL_MAPPING,
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,

View File

@@ -70,6 +70,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("codegen", "CodeGenConfig"),
("cohere", "CohereConfig"),
("cohere2", "Cohere2Config"),
("colpali", "ColPaliConfig"),
("conditional_detr", "ConditionalDetrConfig"),
("convbert", "ConvBertConfig"),
("convnext", "ConvNextConfig"),
@@ -373,6 +374,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("codegen", "CodeGen"),
("cohere", "Cohere"),
("cohere2", "Cohere2"),
("colpali", "ColPali"),
("conditional_detr", "Conditional DETR"),
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),

View File

@@ -306,6 +306,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("big_bird", "BigBirdForPreTraining"),
("bloom", "BloomForCausalLM"),
("camembert", "CamembertForMaskedLM"),
("colpali", "ColPaliForRetrieval"),
("ctrl", "CTRLLMHeadModel"),
("data2vec-text", "Data2VecTextForMaskedLM"),
("deberta", "DebertaForMaskedLM"),
@@ -775,6 +776,12 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_RETRIEVAL_MAPPING_NAMES = OrderedDict(
[
("colpali", "ColPaliForRetrieval"),
]
)
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
[
("aria", "AriaForConditionalGeneration"),
@@ -1473,6 +1480,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FO
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
)
MODEL_FOR_RETRIEVAL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_RETRIEVAL_MAPPING_NAMES)
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES
)

View File

@@ -58,6 +58,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("clip", "CLIPProcessor"),
("clipseg", "CLIPSegProcessor"),
("clvp", "ClvpProcessor"),
("colpali", "ColPaliProcessor"),
("flava", "FlavaProcessor"),
("fuyu", "FuyuProcessor"),
("git", "GitProcessor"),

View File

@@ -148,6 +148,7 @@ else:
("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
(
"cpm",

View File

@@ -0,0 +1,28 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_colpali import *
from .modeling_colpali import *
from .processing_colpali import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ColPali model configuration"""
import logging
from copy import deepcopy
from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig
logger = logging.getLogger(__name__)
class ColPaliConfig(PretrainedConfig):
r"""
Configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate an instance
of `ColPaliForRetrieval` according to the specified arguments, defining the model architecture following the methodology
from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
The ColPali config is very similar to [`PaligemmaConfig`], but with an extra attribute defining the embedding dimension.
Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vlm_config (`PretrainedConfig`, *optional*):
Configuration of the VLM backbone model.
text_config (`PretrainedConfig`, *optional*):
Configuration of the text backbone model. Overrides the `text_config` attribute of the `vlm_config` if provided.
embedding_dim (`int`, *optional*, defaults to 128):
Dimension of the multi-vector embeddings produced by the model.
Example:
```python
from transformers.models.colpali import ColPaliConfig, ColPaliForRetrieval
config = ColPaliConfig()
model = ColPaliForRetrieval(config)
```
"""
model_type = "colpali"
sub_configs = {"vlm_config": PretrainedConfig, "text_config": AutoConfig}
def __init__(
self,
vlm_config=None,
text_config=None,
embedding_dim: int = 128,
**kwargs,
):
if vlm_config is None:
vlm_config = CONFIG_MAPPING["paligemma"]()
logger.info(
"`vlm_config` is `None`. Initializing `vlm_config` with the `PaliGemmaConfig` with default values."
)
elif isinstance(vlm_config, dict):
vlm_config = deepcopy(vlm_config)
if "model_type" not in vlm_config:
raise KeyError(
"The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
)
elif vlm_config["model_type"] not in CONFIG_MAPPING:
raise ValueError(
f"The model type `{vlm_config['model_type']}` is not supported. Please provide a valid model type."
)
vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
elif isinstance(vlm_config, PretrainedConfig):
vlm_config = vlm_config
else:
raise TypeError(
f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
)
self.vlm_config = vlm_config
self.text_config = text_config = text_config if text_config is not None else vlm_config.text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "gemma"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
self.embedding_dim = embedding_dim
super().__init__(**kwargs)
__all__ = ["ColPaliConfig"]

View File

@@ -0,0 +1,207 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert ColPali weights from the original repository to the HF model format.
Original repository: https://github.com/illuin-tech/colpali.
NOTE: This script was originally run using `torch==2.5.1` and with:
```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.2-merged \
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.2-hf-internal \
--push_to_hub
```
"""
import argparse
import glob
from pathlib import Path
from typing import Any, Dict, Optional
import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.colpali import ColPaliForRetrieval
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
ORIGINAL_DTYPE = torch.bfloat16
def rename_state_dict_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
new_state_dict = {}
for key, value in state_dict.items():
new_key = key
if key.startswith("custom_text_proj"):
new_key = key.replace("custom_text_proj", "embedding_proj_layer")
if key.startswith("model."):
new_key = key.replace("model.", "vlm.", 1)
new_state_dict[new_key] = value
return new_state_dict
def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> Dict[str, torch.Tensor]:
directory_path = snapshot_download(
repo_id=model_id,
revision=revision,
allow_patterns=["*.safetensors"],
)
original_state_dict = {}
for path in glob.glob(f"{directory_path}/*"):
if path.endswith(".safetensors"):
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
original_state_dict[key] = f.get_tensor(key)
# Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
if "lm_head.weight" not in original_state_dict:
original_state_dict["vlm.language_model.lm_head.weight"] = original_state_dict[
"model.language_model.model.embed_tokens.weight"
].clone()
return original_state_dict
@torch.no_grad()
def convert_colpali_weights_to_hf(
model_id: str,
output_dir: str,
push_to_hub: bool,
revision: Optional[str] = None,
original_vlm_name_or_path: Optional[str] = None,
):
# Load the original model data
original_config = AutoConfig.from_pretrained(
model_id,
revision=revision,
)
if original_vlm_name_or_path is not None:
original_config._name_or_path = original_vlm_name_or_path
if hasattr(original_config, "architectures"):
delattr(original_config, "architectures")
original_state_dict = load_original_state_dict(model_id, revision=revision)
# Format the state_dict keys
original_state_dict = rename_state_dict_keys(original_state_dict)
# Create the new config
config = ColPaliConfig(
vlm_config=original_config,
embedding_dim=128, # hardcoded in the original model
)
config.model_type = "colpali"
config.is_composition = False
# Load the untrained model
model = ColPaliForRetrieval(config=config).to("cpu").eval()
print("Created model with new config and randomly initialized weights")
# NOTE: The model was initialized with float32 weights. We need to convert it to the desired precision.
# There are two ways to set the model's dtype:
# - Using `model.from_pretrained(..., torch_dtype=dtype_precision)` doesn't convert the hyperparameters to the desired precision.
# - Using `model.to(dtype_precision)` converts all values - including the hyperparameters - to the desired precision.
# The following snippet allows a fine-grained control over the model's dtype, making sure that all
# the new weights' dtypes match the original model.
for param in model.parameters():
param.data = param.data.to(ORIGINAL_DTYPE)
print(f"Converted the new model weights to `{ORIGINAL_DTYPE}`")
# Load the original weights
model.load_state_dict(original_state_dict)
print("Loaded original model weights")
# Tie the weights (following ColPali's `__init__`` step)
if model.vlm.language_model._tied_weights_keys is not None:
model._tied_weights_keys = [f"vlm.language_model.{k}" for k in model.vlm.language_model._tied_weights_keys]
# Sanity check: ensure all keys are the same
state_dict_keys_old = set(original_state_dict.keys())
state_dict_keys_new = set(model.state_dict().keys())
disjoint_keys = state_dict_keys_old.symmetric_difference(state_dict_keys_new)
if disjoint_keys:
raise ValueError(f"Incompatible keys: {disjoint_keys}")
# Save the model
if push_to_hub:
model.push_to_hub(output_dir, private=True)
print(f"Model pushed to the hub at `{output_dir}`")
else:
Path(output_dir).mkdir(exist_ok=True, parents=True)
model.save_pretrained(output_dir)
print(f"Model saved to `{output_dir}`")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script converts the original ColPali model to the HF model format.
Example usage:
```bash
python src/transformers/models/colpali/convert_colpali_weights_to_hf.py \
--model_id vidore/colpali-v1.2-merged \
--revision 89fd9736194236a1ecb7a9ec9b04f537f6f896af \
--original_vlm_name_or_path google/paligemma-3b-mix-448 \
--output_dir vidore/colpali-v1.2-hf \
--push_to_hub
```
"""
)
parser.add_argument(
"--model_id",
help="Model ID of the original model to convert",
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--push_to_hub",
help="Whether or not to push the model to the hub at `output_dir` instead of saving it locally",
action="store_true",
default=False,
)
parser.add_argument(
"--revision",
help="Revision of the model to download",
default=None,
)
parser.add_argument(
"--original_vlm_name_or_path",
help="Name or path of the original VLM backbone model",
default=None,
)
args = parser.parse_args()
convert_colpali_weights_to_hf(
model_id=args.model_id,
output_dir=args.output_dir,
push_to_hub=args.push_to_hub,
revision=args.revision,
original_vlm_name_or_path=args.original_vlm_name_or_path,
)

View File

@@ -0,0 +1,299 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ColPali model"""
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import AutoModelForImageTextToText
from ...cache_utils import Cache
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from .configuration_colpali import ColPaliConfig
_CONFIG_FOR_DOC = "ColPaliConfig"
COLPALI_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`ColPaliConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare ColPali model outputting raw hidden-states without any specific head on top.",
COLPALI_START_DOCSTRING,
)
class ColPaliPreTrainedModel(PreTrainedModel):
config_class = ColPaliConfig
base_model_prefix = "model"
_no_split_modules = []
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.vlm_config.text_config.initializer_range
)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@dataclass
class ColPaliForRetrievalOutput(ModelOutput):
"""
Base class for ColPali embeddings output.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
The embeddings of the model.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
image_hidden_states (`torch.FloatTensor`, *optional*):
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
"""
loss: Optional[torch.FloatTensor] = None
embeddings: torch.Tensor = None
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
image_hidden_states: Optional[torch.FloatTensor] = None
COLPALI_FOR_RETRIEVAL_INPUT_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)):
The tensors corresponding to the input images. Pixel values can be obtained using
[`AutoImageProcessor`]. See [`SiglipImageProcessor.__call__`] for details ([]`PaliGemmaProcessor`] uses
[`SiglipImageProcessor`] for processing images). If none, ColPali will only process text (query embeddings).
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the vlm backbone model.
"""
@add_start_docstrings(
"""
ColPali leverages Vision Language Models (VLMs) to construct efficient multi-vector embeddings in the visual space for document retrieval.
By feeding the ViT output patches from PaliGemma-3B to a linear projection, we create a multi-vector representation of documents. The model
is trained to maximize the similarity between these document embeddings and the query embeddings, following the ColBERT method.
Using ColPali removes the need for potentially complex and brittle layout recognition and OCR pipelines with a single model that can take into account
both the textual and visual content (layout, charts, ...) of a document.
ColPali was introduced in the following paper: [*ColPali: Efficient Document Retrieval with Vision Language Models*](https://arxiv.org/abs/2407.01449).
Resources:
- A blog post detailing ColPali, a vision retrieval model, can be found [here](https://huggingface.co/blog/manu/colpali). 📝
- The code for using and training the original ColPali model and for the `colpali-engine` package can be found [here](https://github.com/illuin-tech/colpali). 🌎
- Cookbooks for learning to use the Hf version of ColPali, fine-tuning, and similarity maps generation can be found [here](https://github.com/tonywu71/colpali-cookbooks). 📚
"""
)
class ColPaliForRetrieval(ColPaliPreTrainedModel):
def __init__(self, config: ColPaliConfig):
super().__init__(config)
self.config = config
self.vocab_size = config.vlm_config.text_config.vocab_size
vlm = AutoModelForImageTextToText.from_config(config.vlm_config)
if vlm.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"vlm.language_model.{k}" for k in vlm.language_model._tied_weights_keys]
self.vlm = vlm
self.embedding_dim = self.config.embedding_dim
self.embedding_proj_layer = nn.Linear(
self.config.vlm_config.text_config.hidden_size,
self.embedding_dim,
)
self.post_init()
@add_start_docstrings_to_model_forward(COLPALI_FOR_RETRIEVAL_INPUT_DOCSTRING)
@replace_return_docstrings(output_type=ColPaliForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, ColPaliForRetrievalOutput]:
r"""
Returns:
"""
if "pixel_values" in kwargs:
kwargs["pixel_values"] = kwargs["pixel_values"].to(dtype=self.dtype)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vlm(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
output_hidden_states=True,
return_dict=return_dict,
output_attentions=output_attentions,
**kwargs,
)
last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
embeddings = self.embedding_proj_layer(last_hidden_states) # (batch_size, sequence_length, dim)
# L2 normalization
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
loss = None
if not return_dict:
output = (embeddings,) + outputs[2:]
output[2] = output[2] if output_hidden_states is not None else None
output[-1] = (outputs.image_hidden_states if pixel_values is not None else None,)
return (loss,) + output if loss is not None else output
return ColPaliForRetrievalOutput(
loss=loss,
embeddings=embeddings,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states if pixel_values is not None else None,
)
def get_input_embeddings(self):
return self.vlm.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.vlm.language_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.vlm.language_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.vlm.language_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.vlm.language_model.set_decoder(decoder)
def get_decoder(self):
return self.vlm.language_model.get_decoder()
def tie_weights(self):
return self.vlm.language_model.tie_weights()
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
mean_resizing: bool = True,
) -> nn.Embedding:
model_embeds = self.vlm.language_model.resize_token_embeddings(
new_num_tokens=new_num_tokens,
pad_to_multiple_of=pad_to_multiple_of,
mean_resizing=mean_resizing,
)
self.config.vlm_config.text_config.vocab_size = model_embeds.num_embeddings
self.config.vlm_config.vocab_size = model_embeds.num_embeddings
self.vlm.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
return model_embeds
__all__ = [
"ColPaliForRetrieval",
"ColPaliForRetrievalOutput",
"ColPaliPreTrainedModel",
]

View File

@@ -0,0 +1,354 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Optional, Union
from transformers.models.paligemma.processing_paligemma import (
IMAGE_TOKEN,
PaliGemmaProcessor,
build_string_from_input,
make_batched_images,
)
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import (
ProcessingKwargs,
Unpack,
)
from ...tokenization_utils_base import (
PreTokenizedInput,
TextInput,
)
from ...utils import (
is_torch_available,
logging,
)
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": "longest",
},
"images_kwargs": {
"data_format": "channels_first",
"do_convert_rgb": True,
},
"common_kwargs": {"return_tensors": "pt"},
}
class ColPaliProcessor(PaliGemmaProcessor):
r"""
Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
well as to compute the late-interaction retrieval score.
[`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
for more information.
Args:
image_processor ([`SiglipImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
visual_prompt_prefix: ClassVar[str] = "Describe the image."
query_prefix: ClassVar[str] = "Question: "
@property
def query_augmentation_token(self) -> str:
"""
Return the query augmentation token.
Query augmentation buffers are used as reasoning buffers during inference.
"""
return self.tokenizer.pad_token
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
both text and images at the same time.
When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
[`~LlamaTokenizerFast.__call__`].
When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
[`~SiglipImageProcessor.__call__`].
Please refer to the doctsring of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
ColPaliProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
return_token_type_ids = True if suffix is not None else False
if text is None and images is None:
raise ValueError("Either text or images must be provided")
if text is not None and images is not None:
raise ValueError("Only one of text or images can be processed at a time")
if images is not None:
if is_valid_image(images):
images = [images]
elif isinstance(images, list) and is_valid_image(images[0]):
pass
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
raise ValueError("images must be an image, list of images or list of list of images")
texts_doc = [self.visual_prompt_prefix] * len(images)
images = [image.convert("RGB") for image in images]
input_strings = [
build_string_from_input(
prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length,
image_token=IMAGE_TOKEN,
num_images=len(image_list) if isinstance(image_list, list) else 1,
)
for prompt, image_list in zip(texts_doc, images)
]
images = make_batched_images(images)
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
# max_length has to account for the image tokens
if output_kwargs["text_kwargs"].get("max_length", None) is not None:
output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
inputs = self.tokenizer(
input_strings,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return_data = {**inputs, "pixel_values": pixel_values}
if return_token_type_ids:
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
return_data.update({"labels": labels})
return BatchFeature(data=return_data)
elif text is not None:
if isinstance(text, str):
text = [text]
elif not (isinstance(text, list) and isinstance(text[0], str)):
raise ValueError("Text must be a string or a list of strings")
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: List[str] = []
for query in text:
query = self.tokenizer.bos_token + self.query_prefix + query
query += suffix # add suffix (pad tokens)
query += "\n" # make input ISO to PaliGemma's processor
texts_query.append(query)
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return batch_query
def process_images(
self,
images: ImageInput = None,
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
return self.__call__(images=images, **kwargs)
def process_queries(
self,
text: Union[TextInput, List[TextInput]],
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
"""
return self.__call__(text=text, **kwargs)
def score_retrieval(
self,
query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
batch_size: int = 128,
output_dtype: Optional["torch.dtype"] = None,
output_device: Union["torch.device", str] = "cpu",
) -> "torch.Tensor":
"""
Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
image of a document page.
Because the embedding tensors are multi-vector and can thus have different shapes, they
should be fed as:
(1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
(2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
obtained by padding the list of tensors.
Args:
query_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
If `None`, the dtype of the input embeddings is used.
output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
Returns:
`torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
tensor is saved on the "cpu" device.
"""
if len(query_embeddings) == 0:
raise ValueError("No queries provided")
if len(passage_embeddings) == 0:
raise ValueError("No passages provided")
if query_embeddings[0].device != passage_embeddings[0].device:
raise ValueError("Queries and passages must be on the same device")
if query_embeddings[0].dtype != passage_embeddings[0].dtype:
raise ValueError("Queries and passages must have the same dtype")
if output_dtype is None:
output_dtype = query_embeddings[0].dtype
scores: List[torch.Tensor] = []
for i in range(0, len(query_embeddings), batch_size):
batch_scores: List[torch.Tensor] = []
batch_queries = torch.nn.utils.rnn.pad_sequence(
query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
)
for j in range(0, len(passage_embeddings), batch_size):
batch_passages = torch.nn.utils.rnn.pad_sequence(
passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
)
batch_scores.append(
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
)
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
return torch.cat(scores, dim=0)
__all__ = [
"ColPaliProcessor",
]

View File

@@ -0,0 +1,443 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/colpali/modular_colpali.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_colpali.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import ClassVar, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
from ...utils import is_torch_available
if is_torch_available():
import torch
class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": "longest",
},
"images_kwargs": {
"data_format": "channels_first",
"do_convert_rgb": True,
},
"common_kwargs": {"return_tensors": "pt"},
}
IMAGE_TOKEN = "<image>"
EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
"""
Builds a string from the input prompt and image tokens.
For example, for the call:
build_string_from_input(
prompt="Prefix str"
bos_token="<s>",
image_seq_len=3,
image_token="<im>",
)
The output will be:
"<im><im><im><s>Initial str"
Args:
prompt (`List[Union[str, ImageInput]]`): The input prompt.
bos_token (`str`): The beginning of sentence token.
image_seq_len (`int`): The length of the image sequence.
image_token (`str`): The image token.
num_images (`int`): Number of images in the prompt.
"""
return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
def make_batched_images(images) -> List[List[ImageInput]]:
"""
Accepts images in list or nested list format, and makes a list of images for preprocessing.
Args:
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
The input image.
Returns:
list: A list of images.
"""
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
return [img for img_list in images for img in img_list]
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
return images
elif is_valid_image(images):
return [images]
raise ValueError(f"Could not make batched video from {images}")
class ColPaliProcessor(ProcessorMixin):
r"""
Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
well as to compute the late-interaction retrieval score.
[`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
for more information.
Args:
image_processor ([`SiglipImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "SiglipImageProcessor"
tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
visual_prompt_prefix: ClassVar[str] = "Describe the image."
query_prefix: ClassVar[str] = "Question: "
def __init__(
self,
image_processor=None,
tokenizer=None,
chat_template=None,
**kwargs,
):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
if not hasattr(image_processor, "image_seq_length"):
raise ValueError("Image processor is missing an `image_seq_length` attribute.")
self.image_seq_length = image_processor.image_seq_length
if not hasattr(tokenizer, "image_token"):
image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
else:
self.image_token_id = tokenizer.image_token_id
tokenizer.add_tokens(EXTRA_TOKENS)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audio=None,
videos=None,
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is custom
wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
both text and images at the same time.
When preparing the the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
[`~LlamaTokenizerFast.__call__`].
When preparing the the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
[`~SiglipImageProcessor.__call__`].
Please refer to the doctsring of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
output_kwargs = self._merge_kwargs(
ColPaliProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
suffix = output_kwargs["text_kwargs"].pop("suffix", None)
return_token_type_ids = True if suffix is not None else False
if text is None and images is None:
raise ValueError("Either text or images must be provided")
if text is not None and images is not None:
raise ValueError("Only one of text or images can be processed at a time")
if images is not None:
if is_valid_image(images):
images = [images]
elif isinstance(images, list) and is_valid_image(images[0]):
pass
elif not (isinstance(images, list) and isinstance(images[0], list) and is_valid_image(images[0][0])):
raise ValueError("images must be an image, list of images or list of list of images")
texts_doc = [self.visual_prompt_prefix] * len(images)
images = [image.convert("RGB") for image in images]
input_strings = [
build_string_from_input(
prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length,
image_token=IMAGE_TOKEN,
num_images=len(image_list) if isinstance(image_list, list) else 1,
)
for prompt, image_list in zip(texts_doc, images)
]
images = make_batched_images(images)
pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
# max_length has to account for the image tokens
if output_kwargs["text_kwargs"].get("max_length", None) is not None:
output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
inputs = self.tokenizer(
input_strings,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return_data = {**inputs, "pixel_values": pixel_values}
if return_token_type_ids:
labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
return_data.update({"labels": labels})
return BatchFeature(data=return_data)
elif text is not None:
if isinstance(text, str):
text = [text]
elif not (isinstance(text, list) and isinstance(text[0], str)):
raise ValueError("Text must be a string or a list of strings")
if suffix is None:
suffix = self.query_augmentation_token * 10
texts_query: List[str] = []
for query in text:
query = self.tokenizer.bos_token + self.query_prefix + query
query += suffix # add suffix (pad tokens)
query += "\n" # make input ISO to PaliGemma's processor
texts_query.append(query)
output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
batch_query = self.tokenizer(
texts_query,
return_token_type_ids=False,
**output_kwargs["text_kwargs"],
)
return batch_query
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
@property
def query_augmentation_token(self) -> str:
"""
Return the query augmentation token.
Query augmentation buffers are used as reasoning buffers during inference.
"""
return self.tokenizer.pad_token
def process_images(
self,
images: ImageInput = None,
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's [`~SiglipImageProcessor.__call__`].
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
return self.__call__(images=images, **kwargs)
def process_queries(
self,
text: Union[TextInput, List[TextInput]],
**kwargs: Unpack[ColPaliProcessorKwargs],
) -> BatchFeature:
"""
Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
[`ColPaliProcessor.__call__`].
This method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`].
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
"""
return self.__call__(text=text, **kwargs)
def score_retrieval(
self,
query_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
passage_embeddings: Union["torch.Tensor", List["torch.Tensor"]],
batch_size: int = 128,
output_dtype: Optional["torch.dtype"] = None,
output_device: Union["torch.device", str] = "cpu",
) -> "torch.Tensor":
"""
Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
image of a document page.
Because the embedding tensors are multi-vector and can thus have different shapes, they
should be fed as:
(1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
(2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
obtained by padding the list of tensors.
Args:
query_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
passage_embeddings (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
If `None`, the dtype of the input embeddings is used.
output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
Returns:
`torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
tensor is saved on the "cpu" device.
"""
if len(query_embeddings) == 0:
raise ValueError("No queries provided")
if len(passage_embeddings) == 0:
raise ValueError("No passages provided")
if query_embeddings[0].device != passage_embeddings[0].device:
raise ValueError("Queries and passages must be on the same device")
if query_embeddings[0].dtype != passage_embeddings[0].dtype:
raise ValueError("Queries and passages must have the same dtype")
if output_dtype is None:
output_dtype = query_embeddings[0].dtype
scores: List[torch.Tensor] = []
for i in range(0, len(query_embeddings), batch_size):
batch_scores: List[torch.Tensor] = []
batch_queries = torch.nn.utils.rnn.pad_sequence(
query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
)
for j in range(0, len(passage_embeddings), batch_size):
batch_passages = torch.nn.utils.rnn.pad_sequence(
passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
)
batch_scores.append(
torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
)
scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
return torch.cat(scores, dim=0)
__all__ = ["ColPaliProcessor"]

View File

@@ -813,6 +813,9 @@ MODEL_FOR_PRETRAINING_MAPPING = None
MODEL_FOR_QUESTION_ANSWERING_MAPPING = None
MODEL_FOR_RETRIEVAL_MAPPING = None
MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = None
@@ -2258,6 +2261,20 @@ class Cohere2PreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class ColPaliForRetrieval(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ColPaliPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]

View File

View File

@@ -0,0 +1,368 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ColPali model."""
import gc
import unittest
from typing import ClassVar
import torch
from datasets import load_dataset
from parameterized import parameterized
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from transformers import (
is_torch_available,
is_vision_available,
)
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from transformers.testing_utils import (
require_torch,
require_torch_sdpa,
require_vision,
slow,
torch_device,
)
if is_torch_available():
import torch
if is_vision_available():
pass
class ColPaliForRetrievalModelTester:
def __init__(
self,
parent,
ignore_index=-100,
image_token_index=0,
projector_hidden_act="gelu",
seq_length=25,
vision_feature_select_strategy="default",
vision_feature_layer=-1,
projection_dim=32,
text_config={
"model_type": "gemma",
"seq_length": 128,
"is_training": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 1,
"head_dim": 8,
"intermediate_size": 37,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 1,
},
is_training=False,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_image_tokens": 4,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"projection_dim": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
use_cache=False,
embedding_dim=128,
):
self.parent = parent
self.ignore_index = ignore_index
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.text_config = text_config
self.vision_config = vision_config
self.seq_length = seq_length
self.projection_dim = projection_dim
self.pad_token_id = text_config["pad_token_id"]
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.encoder_seq_length = seq_length
self.use_cache = use_cache
self.embedding_dim = embedding_dim
self.vlm_config = {
"model_type": "paligemma",
"text_config": self.text_config,
"vision_config": self.vision_config,
"ignore_index": self.ignore_index,
"image_token_index": self.image_token_index,
"projector_hidden_act": self.projector_hidden_act,
"projection_dim": self.projection_dim,
"vision_feature_select_strategy": self.vision_feature_select_strategy,
"vision_feature_layer": self.vision_feature_layer,
}
def get_config(self):
return ColPaliConfig(
vlm_config=self.vlm_config,
embedding_dim=self.embedding_dim,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.vlm_config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
# set the 16 first tokens to be image, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == config.vlm_config.image_token_index] = self.pad_token_id
input_ids[:, :16] = config.vlm_config.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
"token_type_ids": torch.zeros_like(input_ids),
}
return config, inputs_dict
@require_torch
class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `ColPaliForRetrieval`.
"""
all_model_classes = (ColPaliForRetrieval,) if is_torch_available() else ()
fx_compatible = False
test_torchscript = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
def setUp(self):
self.model_tester = ColPaliForRetrievalModelTester(self)
self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
wte = model.get_input_embeddings()
inputs["inputs_embeds"] = wte(input_ids)
with torch.no_grad():
model(**inputs)
# overwrite inputs_embeds tests because we need to delete "pixel values" for LVLMs
# while some other models require pixel_values to be present
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
input_ids = inputs["input_ids"]
del inputs["input_ids"]
del inputs["pixel_values"]
inputs_embeds = model.get_input_embeddings()(input_ids)
with torch.no_grad():
out_ids = model(input_ids=input_ids, **inputs)[0]
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@slow
@require_vision
def test_colpali_forward_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
with torch.no_grad():
outputs = model(**inputs, return_dict=True)
self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@require_torch_sdpa
@slow
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
def test_eager_matches_sdpa_inference(self, torch_dtype: str):
self.skipTest(
"Due to custom causal mask, there is a slightly too big difference between eager and sdpa in bfloat16."
)
@unittest.skip(
reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
)
def test_model_parallelism(self):
pass
@unittest.skip(
reason="PaliGemmma's SigLip encoder uses the same initialization scheme as the Flax original implementation"
)
def test_initialization(self):
pass
# TODO extend valid outputs to include this test @Molbap
@unittest.skip(reason="PaliGemma has currently one output format.")
def test_model_outputs_equivalence(self):
pass
@unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
def test_sdpa_can_compile_dynamic(self):
pass
@require_torch
class ColPaliModelIntegrationTest(unittest.TestCase):
model_name: ClassVar[str] = "vidore/colpali-v1.2-hf"
def setUp(self):
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
def test_model_integration_test(self):
"""
Test if the model is able to retrieve the correct pages for a small and easy dataset.
"""
model = ColPaliForRetrieval.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=torch_device,
).eval()
# Load the test dataset
ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
# Preprocess the examples
batch_images = self.processor(images=ds["image"]).to(torch_device)
batch_queries = self.processor(text=ds["query"]).to(torch_device)
# Run inference
with torch.inference_mode():
image_embeddings = model(**batch_images).embeddings
query_embeddings = model(**batch_queries).embeddings
# Compute retrieval scores
scores = self.processor.score_retrieval(
query_embeddings=query_embeddings,
passage_embeddings=image_embeddings,
) # (len(qs), len(ps))
assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
# Check if the maximum scores per row are in the diagonal of the matrix score
self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
# Further validation: fine-grained check, with a hardcoded score from the original implementation
expected_scores = torch.tensor(
[
[15.5625, 6.5938, 14.4375],
[12.2500, 16.2500, 11.0000],
[15.0625, 11.7500, 21.0000],
],
dtype=scores.dtype,
)
assert torch.allclose(scores, expected_scores, atol=1), f"Expected scores {expected_scores}, got {scores}"

View File

@@ -0,0 +1,247 @@
import shutil
import tempfile
import unittest
import torch
from transformers import GemmaTokenizer
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from transformers.testing_utils import get_tests_dir, require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils.dummy_vision_objects import SiglipImageProcessor
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import (
ColPaliProcessor,
PaliGemmaProcessor,
SiglipImageProcessor,
)
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_vision
class ColPaliProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = ColPaliProcessor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
image_processor.image_seq_length = 0
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True)
processor = PaliGemmaProcessor(image_processor=image_processor, tokenizer=tokenizer)
processor.save_pretrained(self.tmpdirname)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
@require_torch
@require_vision
def test_process_images(self):
# Processor configuration
image_input = self.prepare_image_inputs()
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
image_processor.image_seq_length = 14
# Get the processor
processor = self.processor_class(
tokenizer=tokenizer,
image_processor=image_processor,
)
# Process the image
batch_feature = processor.process_images(images=image_input, return_tensors="pt")
# Assertions
self.assertIn("pixel_values", batch_feature)
self.assertEqual(batch_feature["pixel_values"].shape, torch.Size([1, 3, 384, 384]))
@require_torch
@require_vision
def test_process_queries(self):
# Inputs
queries = [
"Is attention really all you need?",
"Are Benjamin, Antoine, Merve, and Jo best friends?",
]
# Processor configuration
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=112, padding="max_length")
image_processor.image_seq_length = 14
# Get the processor
processor = self.processor_class(
tokenizer=tokenizer,
image_processor=image_processor,
)
# Process the image
batch_feature = processor.process_queries(text=queries, return_tensors="pt")
# Assertions
self.assertIn("input_ids", batch_feature)
self.assertIsInstance(batch_feature["input_ids"], torch.Tensor)
self.assertEqual(batch_feature["input_ids"].shape[0], len(queries))
# The following tests are overwritten as ColPaliProcessor can only take one of images or text as input at a time
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
inputs = processor(text=input_str, return_tensors="pt")
self.assertEqual(inputs[self.text_input_name].shape[-1], 117)
def test_image_processor_defaults_preserved_by_image_kwargs(self):
"""
We use do_rescale=True, rescale_factor=-1 to ensure that image_processor kwargs are preserved in the processor.
We then check that the mean of the pixel_values is less than or equal to 0 after processing.
Since the original pixel_values are in [0, 255], this is a good indicator that the rescale_factor is indeed applied.
"""
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["image_processor"] = self.get_component(
"image_processor", do_rescale=True, rescale_factor=-1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, return_tensors="pt")
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["tokenizer"] = self.get_component("tokenizer", padding="longest")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
inputs = processor(text=input_str, return_tensors="pt", max_length=112, padding="max_length")
self.assertEqual(inputs[self.text_input_name].shape[-1], 112)
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor_components["image_processor"] = self.get_component(
"image_processor", do_rescale=True, rescale_factor=1
)
processor_components["tokenizer"] = self.get_component("tokenizer", max_length=117, padding="max_length")
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
inputs = processor(images=image_input, do_rescale=True, rescale_factor=-1, return_tensors="pt")
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
inputs = processor(
text=input_str,
return_tensors="pt",
do_rescale=True,
rescale_factor=-1,
padding="max_length",
max_length=76,
)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs(batch_size=2)
inputs = processor(
images=image_input,
return_tensors="pt",
do_rescale=True,
rescale_factor=-1,
padding="longest",
max_length=76,
)
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
def test_doubly_passed_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
with self.assertRaises(ValueError):
_ = processor(
images=image_input,
images_kwargs={"do_rescale": True, "rescale_factor": -1},
do_rescale=True,
return_tensors="pt",
)
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"do_rescale": True, "rescale_factor": -1},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(text=input_str, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
image_input = self.prepare_image_inputs()
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"do_rescale": True, "rescale_factor": -1},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}
inputs = processor(images=image_input, **all_kwargs)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)

View File

@@ -87,7 +87,7 @@ def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> str
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
# This is to make sure the transformers module imported is the one in the repo.

View File

@@ -56,7 +56,7 @@ transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
_re_tf_models = re.compile(r"TF(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_flax_models = re.compile(r"Flax(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
# Will match any TF or Flax model too so need to be in an else branch afterthe two previous regexes.
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration)")
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
# Fill this with tuples (pipeline_tag, model_mapping, auto_model)