From e5a9ce48f711b1f26eef3f7047a13b8235e4a71b Mon Sep 17 00:00:00 2001
From: StevenBucaille
Date: Tue, 17 Jun 2025 18:10:23 +0200
Subject: [PATCH] Add LightGlue model (#31718)

* init
* chore: various changes to LightGlue
* chore: various changes to LightGlue
* chore: various changes to LightGlue
* chore: various changes to LightGlue
* Fixed dynamo bug and image padding tests
* refactor: applied refactoring changes from SuperGlue's concat, batch and stack functions to LightGlue file
* tests: removed sdpa support and changed expected values
* chore: added some docs and refactoring
* chore: fixed copy to superpoint.image_processing_superpoint.convert_to_grayscale
* feat: adding batch implementation
* feat: added validation for preprocess and post process method to LightGlueImageProcessor
* chore: changed convert_lightglue_to_hf script to comply with new standard
* chore: changed lightglue test values to match new lightglue config pushed to hub
* chore: simplified convert_lightglue_to_hf conversion map
* feat: adding batching implementation
* chore: make style
* feat: added threshold to post_process_keypoint_matching method
* fix: added missing instructions that turns keypoints back to absolute coordinate before matching forward
* fix: added typehint and docs
* chore: make style
* [run-slow] lightglue
* fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching
* tests: added CUDA proof tests similar to SuperGlue
* chore: various changes to modeling_lightglue.py
- Added "Copies from" statements for copied functions from modeling_superglue.py
- Added missing docstrings
- Removed unused functions or classes
- Removed unnecessary statements
- Added missing typehints
- Added comments to the main forward method
* chore: various changes to convert_lightglue_to_hf.py
- Added model saving
- Added model reloading
* chore: fixed imports in lightglue files
* [run-slow] lightglue
* chore: make style
* [run-slow] lightglue
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii
* [run-slow] lightglue
* chore: Applied some suggestions from review
- Added missing typehints
- Refactor "cuda" to device variable
- Variable renaming
- LightGlue output order changed
- Make style
* fix: added missing grayscale argument in image processor in case use of SuperPoint keypoint detector
* fix: changed lightglue HF repo to lightglue_superpoint with grayscale default to True
* refactor: make keypoints `(batch_size, num_keypoints, keypoint_dim)` through forward and unsqueeze only before attention layer
* refactor: refactor do_layer_keypoint_pruning
* tests: added tests with no early stop and keypoint pruning
* refactor: various refactoring to modeling_lightglue.py
- Removed unused functions
- Renamed variables for consistency
- Added comments for clarity
- Set methods to private in LightGlueForKeypointMatching
- Replaced tensor initialization to list then concatenation
- Used more pythonic list comprehension for repetitive instructions
* refactor: added comments and renamed filter_matches to get_matches_from_scores
* tests: added copied from statement with superglue tests
* docs: added comment to prepare_keypoint_matching_output function in tests
* [run-slow] lightglue
* refactor: reordered _concat_early_stopped_outputs in LightGlue class
* [run-slow] lightglue
* docs: added lightglue.md model doc
* docs: added Optional typehint to LightGlueKeypointMatchingOutput
* chore: removed pad_images function
* chore: set do_grayscale default value to True in LightGlueImageProcessor
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii
* Apply suggestions from code review
Co-authored-by: Pavel Iakubovskii
* docs: added missing LightGlueConfig typehint in nn.Module __init__ methods
* docs: removed unnecessary code in docs
* docs: import SuperPointConfig only from a TYPE_CHECKING context
* chore: use PretrainedConfig arguments `num_hidden_layers` and `num_attention_heads` instead of `num_layers` and `num_heads`
* chore: added organization as arg in convert_lightglue_to_hf.py script
* refactor: set device variable
* chore: added "gelu" in LightGlueConfig as hidden_act parameter
* docs: added comments to reshape.flip.reshape instruction to perform cross attention
* refactor: used batched inference for keypoint detector forward pass
* fix: added fix for SDPA tests
* docs: fixed docstring for LightGlueImageProcessor
* [run-slow] lightglue
* refactor: removed unused line
* refactor: added missing arguments in LightGlueConfig init method
* docs: added missing LightGlueConfig typehint in init methods
* refactor: added checkpoint url as default variable to verify models output only if it is the default url
* fix: moved print message inside if statement
* fix: added log assignment r removal in convert script
* fix: got rid of confidence_thresholds as registered buffers
* refactor: applied suggestions from SuperGlue PR
* docs: changed copyright to 2025
* refactor: modular LightGlue
* fix: removed unnecessary import
* feat: added plot_keypoint_matching method to LightGlueImageProcessor with matplotlib soft dependency
* fix: added missing import error for matplotlib
* Updated convert script to push on ETH org
* fix: added missing licence
* fix: make fix-copies
* refactor: use cohere apply_rotary_pos_emb function
* fix: update model references to use ETH-CVG/lightglue_superpoint
* refactor: add and use intermediate_size attribute in config to inherit CLIPMLP for LightGlueMLP
* refactor: explicit variables instead of slicing
* refactor: use can_return_tuple decorator in LightGlue model
* fix: make fix-copies
* docs: Update model references in `lightglue.md` to use the correct pretrained model from ETH-CVG
* Refactor LightGlue configuration and processing classes
- Updated type hints for `keypoint_detector_config` in `LightGlueConfig` to use `SuperPointConfig` directly.
- Changed `size` parameter in `LightGlueImageProcessor` to be optional.
- Modified `position_embeddings` in `LightGlueAttention` and `LightGlueAttentionBlock` to be optional tuples.
- Cleaned up import statements across multiple files for better readability and consistency.
* refactor: Update LightGlue configuration to enforce eager attention implementation
- Added `attn_implementation="eager"` to `keypoint_detector_config` in `LightGlueConfig` and `LightGlueAttention` classes.
- Removed unnecessary logging related to attention implementation fallback.
- Cleaned up import statements for better readability.
* refactor: renamed message into attention_output
* fix: ensure device compatibility in LightGlueMatchAssignmentLayer descriptor normalization
- Updated the normalization of `m_descriptors` to use the correct device for the tensor, ensuring compatibility across different hardware setups.
* refactor: removed Conv layers from init_weights since LightGlue doesn't have any
* refactor: replace add_start_docstrings with auto_docstring in LightGlue models
- Updated LightGlue model classes to utilize the new auto_docstring utility for automatic documentation generation.
- Removed legacy docstring handling to streamline the code and improve maintainability.
* refactor: simplify LightGlue image processing tests by inheriting from SuperGlue
- Refactored `LightGlueImageProcessingTester` and `LightGlueImageProcessingTest` to inherit from their SuperGlue counterparts, reducing code duplication.
- Removed redundant methods and properties, streamlining the test setup and improving maintainability.
* test: forced eager attention implementation to LightGlue model tests
- Updated `LightGlueModelTester` to include `attn_implementation="eager"` in the model configuration.
- This change aligns the test setup with the recent updates in LightGlue configuration for eager attention.
* refactor: update LightGlue model references
* fix: import error
* test: enhance LightGlue image processing tests with setup method
- Added a setup method in `LightGlueImageProcessingTest` to initialize `LightGlueImageProcessingTester`.
- Included a docstring for `LightGlueImageProcessingTester` to clarify its purpose.
* refactor: added LightGlue image processing implementation to modular file
* refactor: moved attention blocks into the transformer layer
* fix: added missing import
* fix: added missing import in __all__ variable
* doc: added comment about enforcing eager attention because of SuperPoint
* refactor: added SuperPoint eager attention comment and moved functions to the closest they are used

---------

Co-authored-by: Steven Bucaille
Co-authored-by: Pavel Iakubovskii
---
 docs/source/en/_toctree.yml | 2 +
 docs/source/en/model_doc/lightglue.md | 104 ++
 src/transformers/__init__.py | 2 +
 src/transformers/models/__init__.py | 1 +
 .../models/auto/configuration_auto.py | 2 +
 .../models/auto/image_processing_auto.py | 1 +
 src/transformers/models/auto/modeling_auto.py | 1 +
 src/transformers/models/lightglue/__init__.py | 28 +
 .../lightglue/configuration_lightglue.py | 143 +++
 .../lightglue/convert_lightglue_to_hf.py | 281 +++++
 .../lightglue/image_processing_lightglue.py | 452 ++++++++
 .../models/lightglue/modeling_lightglue.py | 926 +++++++++++++++
 .../models/lightglue/modular_lightglue.py | 1000 +++++++++++++++++
 .../superglue/image_processing_superglue.py | 3 +-
 .../models/superpoint/modeling_superpoint.py | 2 +-
 src/transformers/utils/__init__.py | 1 +
 src/transformers/utils/import_utils.py | 5 +
 tests/models/lightglue/__init__.py | 0
 .../test_image_processing_lightglue.py | 96 ++
 .../lightglue/test_modeling_lightglue.py | 584 ++++++++++
 20 files changed, 3632 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/en/model_doc/lightglue.md
 create mode 100644 src/transformers/models/lightglue/__init__.py
 create mode 100644 src/transformers/models/lightglue/configuration_lightglue.py
 create mode 100644 src/transformers/models/lightglue/convert_lightglue_to_hf.py
 create mode 100644 src/transformers/models/lightglue/image_processing_lightglue.py
 create mode 100644 src/transformers/models/lightglue/modeling_lightglue.py
 create mode 100644 src/transformers/models/lightglue/modular_lightglue.py
 create mode 100644 tests/models/lightglue/__init__.py
 create mode 100644 tests/models/lightglue/test_image_processing_lightglue.py
 create mode 100644 tests/models/lightglue/test_modeling_lightglue.py

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 7f9dbaea05..fd9b69ebc1 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -743,6 +743,8 @@
         title: ImageGPT
       - local: model_doc/levit
         title: LeViT
+      - local: model_doc/lightglue
+        title: LightGlue
       - local: model_doc/mask2former
         title: Mask2Former
       - local: model_doc/maskformer
diff --git a/docs/source/en/model_doc/lightglue.md b/docs/source/en/model_doc/lightglue.md
new file mode 100644
index 0000000000..3d9403c420
--- /dev/null
+++ b/docs/source/en/model_doc/lightglue.md
@@ -0,0 +1,104 @@
+
+
+# LightGlue
+
+## Overview
+
+The LightGlue model was proposed in [LightGlue: Local Feature Matching at Light Speed](https://arxiv.org/abs/2306.13643)
+by Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys.
+
+Similar to [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor), this model matches
+two sets of local features extracted from two images, with the goal of being faster than SuperGlue. Paired with the
+[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and
+estimate the pose between them. This model is useful for tasks such as image matching and homography estimation.
+
+The abstract from the paper is the following:
+
+*We introduce LightGlue, a deep neural network that learns to match local features across images. We revisit multiple
+design decisions of SuperGlue, the state of the art in sparse matching, and derive simple but effective improvements.
+Cumulatively, they make LightGlue more efficient - in terms of both memory and computation, more accurate, and much
+easier to train. One key property is that LightGlue is adaptive to the difficulty of the problem: the inference is much
+faster on image pairs that are intuitively easy to match, for example because of a larger visual overlap or limited
+appearance change. This opens up exciting prospects for deploying deep matchers in latency-sensitive applications like
+3D reconstruction. The code and trained models are publicly available at this [https URL](https://github.com/cvg/LightGlue)*
+
+## How to use
+
+Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched.
+The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding
+matching scores.
+```python
+from transformers import AutoImageProcessor, AutoModel
+import torch
+from PIL import Image
+import requests
+
+url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
+image1 = Image.open(requests.get(url_image1, stream=True).raw)
+url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
+image2 = Image.open(requests.get(url_image2, stream=True).raw)
+
+images = [image1, image2]
+
+processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
+model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")
+
+inputs = processor(images, return_tensors="pt")
+with torch.no_grad():
+    outputs = model(**inputs)
+```
+
+You can use the `post_process_keypoint_matching` method from the `LightGlueImageProcessor` to get the keypoints and matches in a readable format:
+```python
+image_sizes = [[(image.height, image.width) for image in images]]
+outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
+for i, output in enumerate(outputs):
+    print("For the image pair", i)
+    for keypoint0, keypoint1, matching_score in zip(
+        output["keypoints0"], output["keypoints1"], output["matching_scores"]
+    ):
+        print(
+            f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
+        )
+```
+
+You can visualize the matches between the images by providing the original images as well as the outputs to this method:
+```python
+processor.plot_keypoint_matching(images, outputs)
+```
+
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/632885ba1558dac67c440aa8/duPp09ty8NRZlMZS18ccP.png)
+
+This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
+The original code can be found [here](https://github.com/cvg/LightGlue).
+
+## LightGlueConfig
+
+[[autodoc]] LightGlueConfig
+
+## LightGlueImageProcessor
+
+[[autodoc]] LightGlueImageProcessor
+
+- preprocess
+- post_process_keypoint_matching
+- plot_keypoint_matching
+
+## LightGlueForKeypointMatching
+
+[[autodoc]] LightGlueForKeypointMatching
+
+- forward
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 155d0fd6d3..5a277749f2 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -231,6 +231,7 @@ _import_structure = {
         "is_faiss_available",
         "is_flax_available",
         "is_keras_nlp_available",
+        "is_matplotlib_available",
         "is_phonemizer_available",
         "is_psutil_available",
         "is_py3nvml_available",
@@ -728,6 +729,7 @@ if TYPE_CHECKING:
         is_faiss_available,
         is_flax_available,
         is_keras_nlp_available,
+        is_matplotlib_available,
         is_phonemizer_available,
         is_psutil_available,
         is_py3nvml_available,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index dea3c98c38..3520b79d69 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -162,6 +162,7 @@ if TYPE_CHECKING:
     from .layoutxlm import *
     from .led import *
     from .levit import *
+    from .lightglue import *
     from .lilt import *
     from .llama import *
     from .llama4 import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index fe8a889b5d..5b46868f64 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -185,6 +185,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
         ("layoutlmv3", "LayoutLMv3Config"),
         ("led", "LEDConfig"),
         ("levit", "LevitConfig"),
+        ("lightglue", "LightGlueConfig"),
         ("lilt", "LiltConfig"),
         ("llama", "LlamaConfig"),
         ("llama4", "Llama4Config"),
@@ -556,6 +557,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
         ("layoutxlm", "LayoutXLM"),
         ("led", "LED"),
         ("levit", "LeViT"),
+        ("lightglue", "LightGlue"),
         ("lilt", "LiLT"),
         ("llama", "LLaMA"),
         ("llama2", "Llama2"),
diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py
index 2faabea5fe..5b71802465 100644
--- a/src/transformers/models/auto/image_processing_auto.py
+++ b/src/transformers/models/auto/image_processing_auto.py
@@ -106,6 +106,7 @@ else:
             ("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
             ("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
             ("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
+            ("lightglue", ("LightGlueImageProcessor",)),
             ("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
             ("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
             ("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 1b776b66ce..fbd0adfe4b 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -175,6 +175,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("layoutlmv3", "LayoutLMv3Model"),
         ("led", "LEDModel"),
         ("levit", "LevitModel"),
+        ("lightglue", "LightGlueForKeypointMatching"),
         ("lilt", "LiltModel"),
         ("llama", "LlamaModel"),
         ("llama4", "Llama4ForConditionalGeneration"),
diff --git a/src/transformers/models/lightglue/__init__.py b/src/transformers/models/lightglue/__init__.py
new file mode 100644
index 0000000000..190e4e4329
--- /dev/null
+++ b/src/transformers/models/lightglue/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_lightglue import *
+    from .image_processing_lightglue import *
+    from .modeling_lightglue import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/lightglue/configuration_lightglue.py b/src/transformers/models/lightglue/configuration_lightglue.py
new file mode 100644
index 0000000000..f0962c0cc7
--- /dev/null
+++ b/src/transformers/models/lightglue/configuration_lightglue.py
@@ -0,0 +1,143 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_lightglue.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...configuration_utils import PretrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+from ..superpoint import SuperPointConfig
+
+
+class LightGlueConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to
+    instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a
+    configuration with the defaults will yield a similar configuration to that of the LightGlue
+    [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
+            The config object or dictionary of the keypoint detector.
+        descriptor_dim (`int`, *optional*, defaults to 256):
+            The dimension of the descriptors.
+        num_hidden_layers (`int`, *optional*, defaults to 9):
+            The number of self and cross attention layers.
+        num_attention_heads (`int`, *optional*, defaults to 4):
+            The number of heads in the multi-head attention.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by mean-pooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        depth_confidence (`float`, *optional*, defaults to 0.95):
+            The confidence threshold used to perform early stopping.
+        width_confidence (`float`, *optional*, defaults to 0.99):
+            The confidence threshold used to prune points.
+        filter_threshold (`float`, *optional*, defaults to 0.1):
+            The confidence threshold used to filter matches.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The activation function to be used in the hidden layers.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        attention_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+
+    Examples:
+        ```python
+        >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching
+
+        >>> # Initializing a LightGlue style configuration
+        >>> configuration = LightGlueConfig()
+
+        >>> # Initializing a model from the LightGlue style configuration
+        >>> model = LightGlueForKeypointMatching(configuration)
+
+        >>> # Accessing the model configuration
+        >>> configuration = model.config
+        ```
+    """
+
+    model_type = "lightglue"
+    sub_configs = {"keypoint_detector_config": AutoConfig}
+
+    def __init__(
+        self,
+        keypoint_detector_config: SuperPointConfig = None,
+        descriptor_dim: int = 256,
+        num_hidden_layers: int = 9,
+        num_attention_heads: int = 4,
+        num_key_value_heads=None,
+        depth_confidence: float = 0.95,
+        width_confidence: float = 0.99,
+        filter_threshold: float = 0.1,
+        initializer_range: float = 0.02,
+        hidden_act: str = "gelu",
+        attention_dropout=0.0,
+        attention_bias=True,
+        **kwargs,
+    ):
+        if descriptor_dim % num_attention_heads != 0:
+            raise ValueError("descriptor_dim % num_heads is different from zero")
+
+        self.descriptor_dim = descriptor_dim
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+
+        self.depth_confidence = depth_confidence
+        self.width_confidence = width_confidence
+        self.filter_threshold = filter_threshold
+        self.initializer_range = initializer_range
+
+        # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
+        # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
+        if isinstance(keypoint_detector_config, dict):
+            keypoint_detector_config["model_type"] = (
+                keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint"
+            )
+            keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
+                **keypoint_detector_config, attn_implementation="eager"
+            )
+        if keypoint_detector_config is None:
+            keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
+
+        self.keypoint_detector_config = keypoint_detector_config
+
+        self.hidden_size = descriptor_dim
+        self.intermediate_size = descriptor_dim * 2
+        self.hidden_act = hidden_act
+        self.attention_dropout = attention_dropout
+        self.attention_bias = attention_bias
+        super().__init__(**kwargs)
+
+
+__all__ = ["LightGlueConfig"]
diff --git a/src/transformers/models/lightglue/convert_lightglue_to_hf.py b/src/transformers/models/lightglue/convert_lightglue_to_hf.py
new file mode 100644
index 0000000000..c1cb2ce587
--- /dev/null
+++ b/src/transformers/models/lightglue/convert_lightglue_to_hf.py
@@ -0,0 +1,281 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import gc
+import os
+import re
+from typing import List
+
+import torch
+from datasets import load_dataset
+
+from transformers import (
+    AutoModelForKeypointDetection,
+    LightGlueForKeypointMatching,
+    LightGlueImageProcessor,
+)
+from transformers.models.lightglue.configuration_lightglue import LightGlueConfig
+
+
+DEFAULT_CHECKPOINT_URL = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_lightglue.pth"
+
+
+def prepare_imgs():
+    dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
+    image0 = dataset[0]["image"]
+    image1 = dataset[1]["image"]
+    image2 = dataset[2]["image"]
+    # [image1, image1] on purpose to test the model early stopping
+    return [[image2, image0], [image1, image1]]
+
+
+def verify_model_outputs(model, device):
+    images = prepare_imgs()
+    preprocessor = LightGlueImageProcessor()
+    inputs = preprocessor(images=images, return_tensors="pt").to(device)
+    model.to(device)
+    with torch.no_grad():
+        outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
+
+    predicted_matches_values = outputs.matches[0, 0, 20:30]
+    predicted_matching_scores_values = outputs.matching_scores[0, 0, 20:30]
+
+    predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
+
+    expected_max_number_keypoints = 866
+    expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
+    expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
+
+    expected_matches_values = torch.tensor([-1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64).to(device)
+    expected_matching_scores_values = torch.tensor([0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583]).to(device)
+
+    expected_number_of_matches = 140
+
+    assert outputs.matches.shape == expected_matches_shape
+    assert outputs.matching_scores.shape == expected_matching_scores_shape
+
+    assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-2)
+    assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)
+
+    assert predicted_number_of_matches == expected_number_of_matches
+
+
+ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
+    r"posenc.Wr": r"positional_encoder.projector",
+    r"self_attn.(\d+).Wqkv": r"transformer_layers.\1.self_attention.Wqkv",
+    r"self_attn.(\d+).out_proj": r"transformer_layers.\1.self_attention.o_proj",
+    r"self_attn.(\d+).ffn.0": r"transformer_layers.\1.self_mlp.fc1",
+    r"self_attn.(\d+).ffn.1": r"transformer_layers.\1.self_mlp.layer_norm",
+    r"self_attn.(\d+).ffn.3": r"transformer_layers.\1.self_mlp.fc2",
+    r"cross_attn.(\d+).to_qk": r"transformer_layers.\1.cross_attention.to_qk",
+    r"cross_attn.(\d+).to_v": r"transformer_layers.\1.cross_attention.v_proj",
+    r"cross_attn.(\d+).to_out": r"transformer_layers.\1.cross_attention.o_proj",
+    r"cross_attn.(\d+).ffn.0": r"transformer_layers.\1.cross_mlp.fc1",
+    r"cross_attn.(\d+).ffn.1": r"transformer_layers.\1.cross_mlp.layer_norm",
+    r"cross_attn.(\d+).ffn.3": r"transformer_layers.\1.cross_mlp.fc2",
+    r"log_assignment.(\d+).matchability": r"match_assignment_layers.\1.matchability",
+    r"log_assignment.(\d+).final_proj": r"match_assignment_layers.\1.final_projection",
+    r"token_confidence.(\d+).token.0": r"token_confidence.\1.token",
+}
+
+
+def convert_old_keys_to_new_keys(state_dict_keys: List[str]):
+    """
+    This function should be applied only once, on the concatenated keys to efficiently rename using
+    the key mappings.
+    """
+    output_dict = {}
+    if state_dict_keys is not None:
+        old_text = "\n".join(state_dict_keys)
+        new_text = old_text
+        for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
+            if replacement is None:
+                new_text = re.sub(pattern, "", new_text)  # an empty line
+                continue
+            new_text = re.sub(pattern, replacement, new_text)
+        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
+    return output_dict
+
+
+def add_keypoint_detector_state_dict(lightglue_state_dict):
+    keypoint_detector = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
+    keypoint_detector_state_dict = keypoint_detector.state_dict()
+    for k, v in keypoint_detector_state_dict.items():
+        lightglue_state_dict[f"keypoint_detector.{k}"] = v
+    return lightglue_state_dict
+
+
+def split_weights(state_dict):
+    for i in range(9):
+        # Remove unused r values
+        log_assignment_r_key = f"log_assignment.{i}.r"
+        if state_dict.get(log_assignment_r_key, None) is not None:
+            state_dict.pop(log_assignment_r_key)
+
+        Wqkv_weight = state_dict.pop(f"transformer_layers.{i}.self_attention.Wqkv.weight")
+        Wqkv_bias = state_dict.pop(f"transformer_layers.{i}.self_attention.Wqkv.bias")
+        Wqkv_weight = Wqkv_weight.reshape(256, 3, 256)
+        Wqkv_bias = Wqkv_bias.reshape(256, 3)
+        query_weight, key_weight, value_weight = Wqkv_weight[:, 0], Wqkv_weight[:, 1], Wqkv_weight[:, 2]
+        query_bias, key_bias, value_bias = Wqkv_bias[:, 0], Wqkv_bias[:, 1], Wqkv_bias[:, 2]
+        state_dict[f"transformer_layers.{i}.self_attention.q_proj.weight"] = query_weight
+        state_dict[f"transformer_layers.{i}.self_attention.k_proj.weight"] = key_weight
+        state_dict[f"transformer_layers.{i}.self_attention.v_proj.weight"] = value_weight
+        state_dict[f"transformer_layers.{i}.self_attention.q_proj.bias"] = query_bias
+        state_dict[f"transformer_layers.{i}.self_attention.k_proj.bias"] = key_bias
+        state_dict[f"transformer_layers.{i}.self_attention.v_proj.bias"] = value_bias
+
+        to_qk_weight = state_dict.pop(f"transformer_layers.{i}.cross_attention.to_qk.weight")
+        to_qk_bias = state_dict.pop(f"transformer_layers.{i}.cross_attention.to_qk.bias")
+        state_dict[f"transformer_layers.{i}.cross_attention.q_proj.weight"] = to_qk_weight
+        state_dict[f"transformer_layers.{i}.cross_attention.q_proj.bias"] = to_qk_bias
+        state_dict[f"transformer_layers.{i}.cross_attention.k_proj.weight"] = to_qk_weight
+        state_dict[f"transformer_layers.{i}.cross_attention.k_proj.bias"] = to_qk_bias
+
+    return state_dict
+
+
+@torch.no_grad()
+def write_model(
+    model_path,
+    checkpoint_url,
+    organization,
+    safe_serialization=True,
+    push_to_hub=False,
+):
+    os.makedirs(model_path, exist_ok=True)
+
+    # ------------------------------------------------------------
+    # LightGlue config
+    # ------------------------------------------------------------
+
+    config = LightGlueConfig(
+        descriptor_dim=256,
+        num_hidden_layers=9,
+        num_attention_heads=4,
+    )
+    config.architectures = ["LightGlueForKeypointMatching"]
+    config.save_pretrained(model_path)
+    print("Model config saved successfully...")
+
+    # ------------------------------------------------------------
+    # Convert weights
+    # ------------------------------------------------------------
+
+    print(f"Fetching all parameters from the checkpoint at {checkpoint_url}...")
+    original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
+
+    print("Converting model...")
+    all_keys = list(original_state_dict.keys())
+    new_keys = convert_old_keys_to_new_keys(all_keys)
+
+    state_dict = {}
+    for key in all_keys:
+        new_key = new_keys[key]
+        state_dict[new_key] = original_state_dict.pop(key).contiguous().clone()
+
+    del original_state_dict
+    gc.collect()
+    state_dict = split_weights(state_dict)
+    state_dict = add_keypoint_detector_state_dict(state_dict)
+
+    print("Loading the checkpoint in a LightGlue model...")
+    device = "cuda"
+    with torch.device(device):
+        model = LightGlueForKeypointMatching(config)
+    model.load_state_dict(state_dict)
+    print("Checkpoint loaded successfully...")
+    del model.config._name_or_path
+
+    print("Saving the model...")
+    model.save_pretrained(model_path, safe_serialization=safe_serialization)
+    del state_dict, model
+
+    # Safety check: reload the converted model
+    gc.collect()
+    print("Reloading the model to check if it's saved correctly.")
+    model = LightGlueForKeypointMatching.from_pretrained(model_path)
+    print("Model reloaded successfully.")
+
+    model_name = "lightglue"
+    if "superpoint" in checkpoint_url:
+        model_name += "_superpoint"
+    if checkpoint_url == DEFAULT_CHECKPOINT_URL:
+        print("Checking the model outputs...")
+        verify_model_outputs(model, device)
+        print("Model outputs verified successfully.")
+
+    if push_to_hub:
+        print("Pushing model to the hub...")
+        model.push_to_hub(
+            repo_id=f"{organization}/{model_name}",
+            commit_message="Add model",
+        )
+        config.push_to_hub(repo_id=f"{organization}/{model_name}", commit_message="Add config")
+
+    write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub)
+
+
+def write_image_processor(save_dir, model_name, organization, push_to_hub=False):
+    if "superpoint" in model_name:
+        image_processor = LightGlueImageProcessor(do_grayscale=True)
+    else:
+        image_processor = LightGlueImageProcessor()
+    image_processor.save_pretrained(save_dir)
+
+    if push_to_hub:
+        print("Pushing image processor to the hub...")
+        image_processor.push_to_hub(
+            repo_id=f"{organization}/{model_name}",
+            commit_message="Add image processor",
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    # Required parameters
+    parser.add_argument(
+        "--checkpoint_url",
+        default=DEFAULT_CHECKPOINT_URL,
+        type=str,
+        help="URL of the original LightGlue checkpoint you'd like to convert.",
+    )
+    parser.add_argument(
+        "--pytorch_dump_folder_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to the output PyTorch model directory.",
+    )
+
parser.add_argument("--save_model", action="store_true", help="Save model to local") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push model and image preprocessor to the hub", + ) + parser.add_argument( + "--organization", + default="ETH-CVG", + type=str, + help="Hub organization in which you want the model to be uploaded.", + ) + + args = parser.parse_args() + write_model( + args.pytorch_dump_folder_path, + args.checkpoint_url, + args.organization, + safe_serialization=True, + push_to_hub=args.push_to_hub, + ) diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py new file mode 100644 index 0000000000..fea0b32df3 --- /dev/null +++ b/src/transformers/models/lightglue/image_processing_lightglue.py @@ -0,0 +1,452 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import resize, to_channel_dimension_format +from ...image_utils import ( + ChannelDimension, + ImageInput, + ImageType, + PILImageResampling, + get_image_type, + infer_channel_dimension_format, + is_pil_image, + is_scaled_image, + is_valid_image, + is_vision_available, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, is_matplotlib_available, logging, requires_backends +from ...utils.import_utils import requires +from .modeling_lightglue import LightGlueKeypointMatchingOutput + + +if is_vision_available(): + import PIL + +logger = logging.get_logger(__name__) + + +def is_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +): + if input_data_format == ChannelDimension.FIRST: + if image.shape[0] == 1: + return True + return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]) + elif input_data_format == ChannelDimension.LAST: + if image.shape[-1] == 1: + return True + return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2]) + + +def convert_to_grayscale( + image: ImageInput, + input_data_format: Optional[Union[str, ChannelDimension]] = None, +) -> ImageInput: + """ + Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch + and tensorflow grayscale conversion + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (Image): + The image to convert. 
+ input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. + """ + requires_backends(convert_to_grayscale, ["vision"]) + + if isinstance(image, np.ndarray): + if is_grayscale(image, input_data_format=input_data_format): + return image + if input_data_format == ChannelDimension.FIRST: + gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=0) + elif input_data_format == ChannelDimension.LAST: + gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140 + gray_image = np.stack([gray_image] * 3, axis=-1) + return gray_image + + if not isinstance(image, PIL.Image.Image): + return image + + image = image.convert("L") + return image + + +def validate_and_format_image_pairs(images: ImageInput): + error_message = ( + "Input images must be one of the following:", + " - A pair of PIL images.", + " - A pair of 3D arrays.", + " - A list of pairs of PIL images.", + " - A list of pairs of 3D arrays.", + ) + + def _is_valid_image(image): + """Return whether `image` is a PIL image or a 3D array.""" + return is_pil_image(image) or ( + is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3 + ) + + if isinstance(images, list): + if len(images) == 2 and all((_is_valid_image(image)) for image in images): + return images + if all( + isinstance(image_pair, list) + and len(image_pair) == 2 + and all(_is_valid_image(image) for image in image_pair) + for image_pair in images + ): + return [image for image_pair in images for image in image_pair] + raise ValueError(error_message) + + +@requires(backends=("torch",)) +class LightGlueImageProcessor(BaseImageProcessor): + r""" + Constructs a LightGlue image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Controls whether to resize the image's (height, width) dimensions to the specified `size`. 
Can be overridden + by `do_resize` in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`): + Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to + `True`. Can be overridden by `size` in the `preprocess` method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in + the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess` + method. + do_grayscale (`bool`, *optional*, defaults to `True`): + Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: float = 1 / 255, + do_grayscale: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 480, "width": 640} + size = get_size_dict(size, default_to_square=False) + + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_grayscale = do_grayscale + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Resize an image. + + Args: + image (`np.ndarray`): + Image to resize. 
+ size (`Dict[str, int]`): + Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the output image. If not provided, it will be inferred from the input + image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + size = get_size_dict(size, default_to_square=False) + + return resize( + image, + size=(size["height"], size["width"]), + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + def preprocess( + self, + images, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: Optional[PILImageResampling] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_grayscale: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> BatchFeature: + """ + Preprocess an image pair or a batch of image pairs. + + Args: + images (`ImageInput`): + Image pairs to preprocess. Expects either a list of 2 images or a list of lists of 2 images, with + pixel values ranging from 0 to 255. 
If passing in images with pixel values between 0 and 1, set + `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the output image after `resize` has been applied, given as a dictionary of the form + `{"height": int, "width": int}`. Only has an effect if `do_resize` is set to `True`. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` filters. Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image pixel values to the range [0, 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`): + Whether to convert the image to grayscale. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + + # Validate and convert the input images into a flattened list of images for all subsequent processing steps. + images = validate_and_format_image_pairs(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + validate_preprocess_arguments( + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+ ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + all_images = [] + for image in images: + if do_resize: + image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_grayscale: + image = convert_to_grayscale(image, input_data_format=input_data_format) + + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + all_images.append(image) + + # Convert the flattened list of images back into a list of image pairs. + image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)] + + data = {"pixel_values": image_pairs} + + return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_keypoint_matching( + self, + outputs: LightGlueKeypointMatchingOutput, + target_sizes: Union[TensorType, List[Tuple]], + threshold: float = 0.0, + ) -> List[Dict[str, torch.Tensor]]: + """ + Converts the raw output of [`LightGlueKeypointMatchingOutput`] into lists of matched keypoints and matching + scores, with keypoint coordinates expressed in absolute pixel coordinates of the original images. + Args: + outputs ([`LightGlueKeypointMatchingOutput`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`): + Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the + target size `(height, width)` of each image in the batch. This must be the original image size (before + any processing). + threshold (`float`, *optional*, defaults to 0.0): + Threshold to filter out the matches with low scores. + Returns: + `List[Dict]`: A list of dictionaries, one per image pair, each containing the matched keypoints in the first + and second image of the pair (`keypoints0`, `keypoints1`) and the corresponding `matching_scores`. 
+ """ + if outputs.mask.shape[0] != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + if not all(len(target_size) == 2 for target_size in target_sizes): + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + if isinstance(target_sizes, List): + image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_pair_sizes = target_sizes + + keypoints = outputs.keypoints.clone() + keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2) + keypoints = keypoints.to(torch.int32) + + results = [] + for mask_pair, keypoints_pair, matches, scores in zip( + outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0] + ): + mask0 = mask_pair[0] > 0 + mask1 = mask_pair[1] > 0 + keypoints0 = keypoints_pair[0][mask0] + keypoints1 = keypoints_pair[1][mask1] + matches0 = matches[mask0] + scores0 = scores[mask0] + + # Filter out matches with low scores + valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1) + + matched_keypoints0 = keypoints0[valid_matches] + matched_keypoints1 = keypoints1[matches0[valid_matches]] + matching_scores = scores0[valid_matches] + + results.append( + { + "keypoints0": matched_keypoints0, + "keypoints1": matched_keypoints1, + "matching_scores": matching_scores, + } + ) + + return results + + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: List[Dict[str, torch.Tensor]]): + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires + matplotlib to be installed. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. 
Expects either a list of 2 images or + a list of lists of 2 images, with pixel values ranging from 0 to 255. + keypoint_matching_output (`List[Dict[str, torch.Tensor]]`): + Post-processed matches, as returned by `post_process_keypoint_matching`. + """ + if is_matplotlib_available(): + import matplotlib.pyplot as plt + else: + raise ImportError("Please install matplotlib to use the `plot_keypoint_matching` method") + + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) + plot_image[:height0, :width0] = image_pair[0] / 255.0 + plot_image[:height1, width0:] = image_pair[1] / 255.0 + plt.imshow(plot_image) + plt.axis("off") + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + plt.plot( + [keypoint0_x, keypoint1_x + width0], + [keypoint0_y, keypoint1_y], + color=plt.get_cmap("RdYlGn")(matching_score.item()), + alpha=0.9, + linewidth=0.5, + ) + plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) + plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2) + plt.show() + + +__all__ = ["LightGlueImageProcessor"] diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py new file mode 100644 index 0000000000..2cd8b0732f --- /dev/null +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -0,0 +1,926 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from 
src/transformers/models/lightglue/modular_lightglue.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_lightglue.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...activations import ACT2FN +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, auto_docstring +from ...utils.generic import can_return_tuple +from ..auto.modeling_auto import AutoModelForKeypointDetection +from .configuration_lightglue import LightGlueConfig + + +@dataclass +class LightGlueKeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching, + the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. 
In the + batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask + tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint + matching information. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Pruning mask indicating which keypoints are removed and at which layer. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching + information. 
+ hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)` returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True` + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)` returned when `output_attentions=True` is passed or when + `config.output_attentions=True` + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + prune: Optional[torch.IntTensor] = None + mask: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class LightGluePositionalEncoder(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) + + def forward( + self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + projected_keypoints = self.projector(keypoints) + embeddings = projected_keypoints.repeat_interleave(2, dim=-1) + cosines = torch.cos(embeddings) + sines = torch.sin(embeddings) + embeddings = (cosines, sines) + output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) + return output + + +def rotate_half(x): + # Split and rotate. Note that this function is different from e.g. Llama. 
+ x1 = x[..., ::2] + x2 = x[..., 1::2] + rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2) + return rot_x + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + dtype = q.dtype + q = q.float() + k = k.float() + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype) + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class LightGlueAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias 
+ ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, 
-1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightGlueMLP(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class LightGlueTransformerLayer(nn.Module): + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.self_attention = LightGlueAttention(config, layer_idx) + self.self_mlp = LightGlueMLP(config) + self.cross_attention = LightGlueAttention(config, layer_idx) + self.cross_mlp = LightGlueMLP(config) + + def forward( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + batch_size, num_keypoints, descriptor_dim = descriptors.shape + + # Self attention block + attention_output, self_attentions = self.self_attention( + descriptors, + position_embeddings=keypoints, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + intermediate_states = torch.cat([descriptors, attention_output], dim=-1) + output_states = 
self.self_mlp(intermediate_states) + self_attention_descriptors = descriptors + output_states + + if output_hidden_states: + self_attention_hidden_states = (intermediate_states, output_states) + + # Reshape hidden_states to group by image_pairs : + # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + # Flip dimension 1 to perform cross attention : + # (image0, image1) -> (image1, image0) + # Reshape back to original shape : + # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) + encoder_hidden_states = ( + self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) + .flip(1) + .reshape(batch_size, num_keypoints, descriptor_dim) + ) + # Same for mask + encoder_attention_mask = ( + attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if attention_mask is not None + else None + ) + + # Cross attention block + cross_attention_output, cross_attentions = self.cross_attention( + self_attention_descriptors, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) + cross_output_states = self.cross_mlp(cross_intermediate_states) + descriptors = self_attention_descriptors + cross_output_states + + if output_hidden_states: + cross_attention_hidden_states = (cross_intermediate_states, cross_output_states) + all_hidden_states = ( + all_hidden_states + + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + self_attention_hidden_states + + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + cross_attention_hidden_states + ) + + if output_attentions: + all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) + + return descriptors, all_hidden_states, all_attentions + + +def 
sigmoid_log_double_softmax( + similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape + certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) + scores0 = nn.functional.log_softmax(similarity, 2) + scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) + scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties + scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) + scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) + return scores + + +class LightGlueMatchAssignmentLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.descriptor_dim = config.descriptor_dim + self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) + self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) + + def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + batch_size, num_keypoints, descriptor_dim = descriptors.shape + # Final projection and similarity computation + m_descriptors = self.final_projection(descriptors) + m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 + m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) + m_descriptors0 = m_descriptors[:, 0] + m_descriptors1 = m_descriptors[:, 1] + similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) + if mask is not None: + mask = mask.reshape(batch_size // 2, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1) + mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) + mask = mask0 * mask1 + similarity = 
similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) + + # Compute matchability of descriptors + matchability = self.matchability(descriptors) + matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) + matchability_0 = matchability[:, 0] + matchability_1 = matchability[:, 1] + + # Compute scores from similarity and matchability + scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) + return scores + + def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: + """Get matchability of descriptors as a probability""" + matchability = self.matchability(descriptors) + matchability = nn.functional.sigmoid(matchability).squeeze(-1) + return matchability + + +class LightGlueTokenConfidenceLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.token = nn.Linear(config.descriptor_dim, 1) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + token = self.token(descriptors.detach()) + token = nn.functional.sigmoid(token).squeeze(-1) + return token + + +@auto_docstring +class LightGluePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LightGlueConfig + base_model_prefix = "lightglue" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: + """obtain matches from a score matrix [Bx M+1 x N+1]""" + batch_size, _, _ = scores.shape + # For each keypoint, get the best match + max0 = scores[:, :-1, :-1].max(2) + max1 = scores[:, :-1, :-1].max(1) + matches0 = max0.indices + matches1 = max1.indices + + # Mutual check for matches + indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None] + indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None] + mutual0 = indices0 == matches1.gather(1, matches0) + mutual1 = indices1 == matches0.gather(1, matches1) + + # Get matching scores and filter based on mutual check and thresholding + max0 = max0.values.exp() + zero = max0.new_tensor(0) + matching_scores0 = torch.where(mutual0, max0, zero) + matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero) + valid0 = mutual0 & (matching_scores0 > threshold) + valid1 = mutual1 & valid0.gather(1, matches1) + + # Filter matches based on mutual check and thresholding of scores + matches0 = torch.where(valid0, matches0, -1) + matches1 = torch.where(valid1, matches1, -1) + matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1) + matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1) + + return matches, matching_scores + + 
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Normalize keypoint locations based on the image shape. + + Args: + keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Keypoint locations in (x, y) format. + height (`int`): + Image height. + width (`int`): + Image width. + + Returns: + Normalized keypoint locations (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`). + """ + size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] + shift = size / 2 + scale = size.max(-1).values / 2 + keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None] + return keypoints + + +@auto_docstring( + custom_intro=""" + LightGlue model taking images as inputs and outputting the matches between them. + """ +) +class LightGlueForKeypointMatching(LightGluePreTrainedModel): + """ + LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as + SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient. + It consists of: + 1. Keypoint Encoder + 2. A Graph Neural Network with self and cross attention layers + 3. Matching Assignment layers + + The correspondence ids use -1 to indicate non-matching points. + + Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed. + In ICCV 2023.
https://arxiv.org/pdf/2306.13643.pdf + """ + + def __init__(self, config: LightGlueConfig): + super().__init__(config) + + self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config) + + self.descriptor_dim = config.descriptor_dim + self.num_layers = config.num_hidden_layers + self.filter_threshold = config.filter_threshold + self.depth_confidence = config.depth_confidence + self.width_confidence = config.width_confidence + + if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim: + self.input_projection = nn.Linear( + config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True + ) + else: + self.input_projection = nn.Identity() + + self.positional_encoder = LightGluePositionalEncoder(config) + + self.transformer_layers = nn.ModuleList( + [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.match_assignment_layers = nn.ModuleList( + [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.token_confidence = nn.ModuleList( + [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def _get_confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold for a given layer""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) + return np.clip(threshold, 0, 1) + + def _keypoint_processing( + self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + descriptors = descriptors.detach().contiguous() + projected_descriptors = self.input_projection(descriptors) + keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) + return projected_descriptors, keypoint_encoding_output + + def _get_early_stopped_image_pairs( + self, keypoint_confidences: 
torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor + ) -> torch.Tensor: + """evaluate whether we should stop inference based on the confidence of the keypoints""" + batch_size, _ = mask.shape + if layer_index < self.num_layers - 1: + # If the current layer is not the last layer, we compute the confidence of the keypoints and check + # if we should stop the forward pass through the transformer layers for each pair of images. + keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) + keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) + threshold = self._get_confidence_threshold(layer_index) + ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points + early_stopped_pairs = ratio_confident > self.depth_confidence + else: + # If the current layer is the last layer, we stop the forward pass through the transformer layers for + # all pairs of images. + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + return early_stopped_pairs + + def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): + if early_stops is not None: + descriptors = descriptors[early_stops] + mask = mask[early_stops] + scores = self.match_assignment_layers[layer_index](descriptors, mask) + matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) + return matches, matching_scores + + def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. 
+ keep |= confidences <= self._get_confidence_threshold(layer_index) + return keep + + def _do_layer_keypoint_pruning( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + mask: torch.Tensor, + indices: torch.Tensor, + prune_output: torch.Tensor, + keypoint_confidences: torch.Tensor, + layer_index: int, + ): + """ + For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the + descriptors. + """ + batch_size, _, _ = descriptors.shape + descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) + pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) + pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) + + # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. + pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( + [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] + for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] + ) + for i in range(batch_size): + prune_output[i, pruned_indices[i]] += 1 + + # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. 
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = ( + pad_sequence(pruned_tensor, batch_first=True) + for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask] + ) + pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1) + pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1) + + return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output + + def _concat_early_stopped_outputs( + self, + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ): + early_stops_indices = torch.stack(early_stops_indices) + matches, final_pruned_keypoints_indices = ( + pad_sequence(tensor, batch_first=True, padding_value=-1) + for tensor in [matches, final_pruned_keypoints_indices] + ) + matching_scores, final_pruned_keypoints_iterations = ( + pad_sequence(tensor, batch_first=True, padding_value=0) + for tensor in [matching_scores, final_pruned_keypoints_iterations] + ) + matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = ( + tensor[early_stops_indices] + for tensor in [ + matches, + matching_scores, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + ] + ) + return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores + + def _do_final_keypoint_pruning( + self, + indices: torch.Tensor, + matches: torch.Tensor, + matching_scores: torch.Tensor, + num_keypoints: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to + # have tensors from + batch_size, _ = indices.shape + indices, matches, matching_scores = ( + tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores] + ) + indices0 = indices[:, 0] + indices1 = indices[:, 1] + matches0 = matches[:, 0] + 
matches1 = matches[:, 1] + matching_scores0 = matching_scores[:, 0] + matching_scores1 = matching_scores[:, 1] + + # Prepare final matches and matching scores + _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype) + _matching_scores = torch.zeros( + (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype + ) + # Fill the matches and matching scores for each image pair + for i in range(batch_size // 2): + _matches[i, 0, indices0[i]] = torch.where( + matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0)) + ) + _matches[i, 1, indices1[i]] = torch.where( + matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0)) + ) + _matching_scores[i, 0, indices0[i]] = matching_scores0[i] + _matching_scores[i, 1, indices1[i]] = matching_scores1[i] + return _matches, _matching_scores + + def _match_image_pair( + self, + keypoints: torch.Tensor, + descriptors: torch.Tensor, + height: int, + width: int, + mask: torch.Tensor = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + device = keypoints.device + batch_size, _, initial_num_keypoints, _ = keypoints.shape + num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) + mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None + descriptors = 
descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim) + image_indices = torch.arange(batch_size * 2, device=device) + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + descriptors, keypoint_encoding_output = self._keypoint_processing( + descriptors, keypoints, output_hidden_states=output_hidden_states + ) + + keypoints = keypoint_encoding_output[0] + + # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the + # keypoints is above a certain threshold. + do_early_stop = self.depth_confidence > 0 + # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of + # the keypoints is below a certain threshold. + do_keypoint_pruning = self.width_confidence > 0 + + early_stops_indices = [] + matches = [] + matching_scores = [] + final_pruned_keypoints_indices = [] + final_pruned_keypoints_iterations = [] + + pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) + pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) + + for layer_index in range(self.num_layers): + input_shape = descriptors.size() + if mask is not None: + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) + layer_output = self.transformer_layers[layer_index]( + descriptors, + keypoints, + attention_mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors, hidden_states, attention = layer_output + if output_hidden_states: + all_hidden_states = all_hidden_states + hidden_states + if output_attentions: + all_attentions = all_attentions + attention + + if do_early_stop: + if layer_index < self.num_layers - 1: + # Get the confidence of the keypoints for the current layer + 
keypoint_confidences = self.token_confidence[layer_index](descriptors) + + # Determine which pairs of images should be early stopped based on the confidence of the keypoints for + # the current layer. + early_stopped_pairs = self._get_early_stopped_image_pairs( + keypoint_confidences, layer_index, mask, num_points=num_points_per_pair + ) + else: + # Early stopping always occurs at the last layer + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + + if torch.any(early_stopped_pairs): + # If a pair of images is considered early stopped, we compute the matches for the remaining + # keypoints and stop the forward pass through the transformer layers for this pair of images. + early_stops = early_stopped_pairs.repeat_interleave(2) + early_stopped_image_indices = image_indices[early_stops] + early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( + descriptors, mask, layer_index, early_stops=early_stops + ) + early_stops_indices.extend(list(early_stopped_image_indices)) + matches.extend(list(early_stopped_matches)) + matching_scores.extend(list(early_stopped_matching_scores)) + if do_keypoint_pruning: + final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) + final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) + + # Remove image pairs that have been early stopped from the forward pass + num_points_per_pair = num_points_per_pair[~early_stopped_pairs] + descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( + ( + tensor[~early_stops] + for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] + ) + ) + keypoints = (keypoints_0, keypoint_1) + if do_keypoint_pruning: + pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( + ( + tensor[~early_stops] + for tensor in [ + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + ] + ) + ) + # If all pairs of images are early stopped, we 
stop the forward pass through the transformer + # layers for all pairs of images. + if torch.all(early_stopped_pairs): + break + + if do_keypoint_pruning: + # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of + # the keypoints is below a certain threshold. + descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( + self._do_layer_keypoint_pruning( + descriptors, + keypoints, + mask, + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + layer_index, + ) + ) + + if do_early_stop and do_keypoint_pruning: + # Concatenate early stopped outputs together and perform final keypoint pruning + final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( + self._concat_early_stopped_outputs( + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ) + ) + matches, matching_scores = self._do_final_keypoint_pruning( + final_pruned_keypoints_indices, + matches, + matching_scores, + initial_num_keypoints, + ) + else: + matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) + final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers + + final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( + batch_size, 2, initial_num_keypoints + ) + + return ( + matches, + matching_scores, + final_pruned_keypoints_iterations, + all_hidden_states, + all_attentions, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, LightGlueKeypointMatchingOutput]: + loss = None + if labels is not None: + raise ValueError("LightGlue is not trainable, no labels should be provided.") + + output_attentions 
= output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values.ndim != 5 or pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, _, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return LightGlueKeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + prune=prune, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"] diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py new file mode 100644 index 0000000000..482c230fb8 --- /dev/null +++ b/src/transformers/models/lightglue/modular_lightglue.py @@ -0,0 +1,1000 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn.utils.rnn import pad_sequence + +from ...configuration_utils import PretrainedConfig +from ...image_utils import ImageInput, to_numpy_array +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging +from ...utils.generic import can_return_tuple +from ..auto import CONFIG_MAPPING, AutoConfig +from ..auto.modeling_auto import AutoModelForKeypointDetection +from ..clip.modeling_clip import CLIPMLP +from ..cohere.modeling_cohere import apply_rotary_pos_emb +from ..llama.modeling_llama import LlamaAttention, eager_attention_forward +from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs +from ..superpoint import SuperPointConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC_ = "LightGlueConfig" +_CHECKPOINT_FOR_DOC_ = "ETH-CVG/lightglue_superpoint" + + +class LightGlueConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. 
It is used to + instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the LightGlue + [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`): + The config object or dictionary of the keypoint detector. + descriptor_dim (`int`, *optional*, defaults to 256): + The dimension of the descriptors. + num_hidden_layers (`int`, *optional*, defaults to 9): + The number of self and cross attention layers. + num_attention_heads (`int`, *optional*, defaults to 4): + The number of heads in the multi-head attention. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if + `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by mean-pooling all the original heads within that group. For more details, check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to + `num_attention_heads`.
+ depth_confidence (`float`, *optional*, defaults to 0.95): + The confidence threshold used to perform early stopping + width_confidence (`float`, *optional*, defaults to 0.99): + The confidence threshold used to prune points + filter_threshold (`float`, *optional*, defaults to 0.1): + The confidence threshold used to filter matches + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function to be used in the hidden layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + attention_bias (`bool`, *optional*, defaults to `True`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + + Examples: + ```python + >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching + + >>> # Initializing a LightGlue style configuration + >>> configuration = LightGlueConfig() + + >>> # Initializing a model from the LightGlue style configuration + >>> model = LightGlueForKeypointMatching(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "lightglue" + sub_configs = {"keypoint_detector_config": AutoConfig} + + def __init__( + self, + keypoint_detector_config: SuperPointConfig = None, + descriptor_dim: int = 256, + num_hidden_layers: int = 9, + num_attention_heads: int = 4, + num_key_value_heads=None, + depth_confidence: float = 0.95, + width_confidence: float = 0.99, + filter_threshold: float = 0.1, + initializer_range: float = 0.02, + hidden_act: str = "gelu", + attention_dropout=0.0, + attention_bias=True, + **kwargs, + ): + if descriptor_dim % num_attention_heads != 0: + raise ValueError("descriptor_dim % num_heads is different from zero") + + self.descriptor_dim = descriptor_dim + self.num_hidden_layers 
= num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.initializer_range = initializer_range + + # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention + # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153 + if isinstance(keypoint_detector_config, dict): + keypoint_detector_config["model_type"] = ( + keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint" + ) + keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]]( + **keypoint_detector_config, attn_implementation="eager" + ) + if keypoint_detector_config is None: + keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager") + + self.keypoint_detector_config = keypoint_detector_config + + self.hidden_size = descriptor_dim + self.intermediate_size = descriptor_dim * 2 + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + super().__init__(**kwargs) + + +@dataclass +class LightGlueKeypointMatchingOutput(ModelOutput): + """ + Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching, + the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the + batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask + tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint + matching information. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Loss computed during training. + matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Index of keypoint matched in the other image. + matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`): + Scores of predicted matches. + keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`): + Absolute (x, y) coordinates of predicted keypoints in a given image. + prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`): + Pruning mask indicating which keypoints are removed and at which layer. + mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`): + Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching + information. + hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels, + num_keypoints)` returned when `output_hidden_states=True` is passed or when + `config.output_hidden_states=True` + attentions (`Tuple[torch.FloatTensor, ...]`, *optional*): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints, + num_keypoints)` returned when `output_attentions=True` is passed or when + `config.output_attentions=True` + """ + + loss: Optional[torch.FloatTensor] = None + matches: Optional[torch.FloatTensor] = None + matching_scores: Optional[torch.FloatTensor] = None + keypoints: Optional[torch.FloatTensor] = None + prune: Optional[torch.IntTensor] = None + mask: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class LightGlueImageProcessor(SuperGlueImageProcessor): + def post_process_keypoint_matching( + self, + outputs: LightGlueKeypointMatchingOutput, + target_sizes: Union[TensorType, List[Tuple]], + threshold: float = 0.0, + ) -> 
List[Dict[str, torch.Tensor]]: + return super().post_process_keypoint_matching(outputs, target_sizes, threshold) + + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires + matplotlib to be installed. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or + a list of list of 2 images list with pixel values ranging from 0 to 255. + outputs ([`LightGlueKeypointMatchingOutput`]): + Raw outputs of the model. + """ + if is_matplotlib_available(): + import matplotlib.pyplot as plt + else: + raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method") + + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3)) + plot_image[:height0, :width0] = image_pair[0] / 255.0 + plot_image[:height1, width0:] = image_pair[1] / 255.0 + plt.imshow(plot_image) + plt.axis("off") + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + plt.plot( + [keypoint0_x, keypoint1_x + width0], + [keypoint0_y, keypoint1_y], + color=plt.get_cmap("RdYlGn")(matching_score.item()), + alpha=0.9, + linewidth=0.5, + ) + plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2) + plt.scatter(keypoint1_x + width0, keypoint1_y, 
c="black", s=2) + plt.show() + + +class LightGluePositionalEncoder(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False) + + def forward( + self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + projected_keypoints = self.projector(keypoints) + embeddings = projected_keypoints.repeat_interleave(2, dim=-1) + cosines = torch.cos(embeddings) + sines = torch.sin(embeddings) + embeddings = (cosines, sines) + output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,) + return output + + +class LightGlueAttention(LlamaAttention): + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attention_interface: Callable = eager_attention_forward + if 
self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + current_attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class LightGlueMLP(CLIPMLP): + def __init__(self, config: LightGlueConfig): + super().__init__(config) + self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size) + self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class LightGlueTransformerLayer(nn.Module): + def __init__(self, config: LightGlueConfig, layer_idx: int): + super().__init__() + self.self_attention = LightGlueAttention(config, layer_idx) + self.self_mlp = LightGlueMLP(config) + self.cross_attention = LightGlueAttention(config, layer_idx) + self.cross_mlp = LightGlueMLP(config) + + def forward( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + attention_mask: torch.Tensor, + output_hidden_states: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if output_hidden_states: + all_hidden_states = all_hidden_states + (descriptors,) + + batch_size, num_keypoints, descriptor_dim = descriptors.shape + + # Self attention block + attention_output, 
self_attentions = self.self_attention( + descriptors, + position_embeddings=keypoints, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + intermediate_states = torch.cat([descriptors, attention_output], dim=-1) + output_states = self.self_mlp(intermediate_states) + self_attention_descriptors = descriptors + output_states + + if output_hidden_states: + self_attention_hidden_states = (intermediate_states, output_states) + + # Reshape hidden_states to group by image_pairs : + # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim) + # Flip dimension 1 to perform cross attention : + # (image0, image1) -> (image1, image0) + # Reshape back to original shape : + # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim) + encoder_hidden_states = ( + self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim) + .flip(1) + .reshape(batch_size, num_keypoints, descriptor_dim) + ) + # Same for mask + encoder_attention_mask = ( + attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints) + if attention_mask is not None + else None + ) + + # Cross attention block + cross_attention_output, cross_attentions = self.cross_attention( + self_attention_descriptors, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1) + cross_output_states = self.cross_mlp(cross_intermediate_states) + descriptors = self_attention_descriptors + cross_output_states + + if output_hidden_states: + cross_attention_hidden_states = (cross_intermediate_states, cross_output_states) + all_hidden_states = ( + all_hidden_states + + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + self_attention_hidden_states + + 
(descriptors.reshape(batch_size, num_keypoints, descriptor_dim),) + + cross_attention_hidden_states + ) + + if output_attentions: + all_attentions = all_attentions + (self_attentions,) + (cross_attentions,) + + return descriptors, all_hidden_states, all_attentions + + +def sigmoid_log_double_softmax( + similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor +) -> torch.Tensor: + """create the log assignment matrix from logits and similarity""" + batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape + certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2) + scores0 = nn.functional.log_softmax(similarity, 2) + scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2) + scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0) + scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties + scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1)) + scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1)) + return scores + + +class LightGlueMatchAssignmentLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.descriptor_dim = config.descriptor_dim + self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True) + self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True) + + def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + batch_size, num_keypoints, descriptor_dim = descriptors.shape + # Final projection and similarity computation + m_descriptors = self.final_projection(descriptors) + m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25 + m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim) + m_descriptors0 = m_descriptors[:, 0] + m_descriptors1 = m_descriptors[:, 1] 
+ similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2) + if mask is not None: + mask = mask.reshape(batch_size // 2, 2, num_keypoints) + mask0 = mask[:, 0].unsqueeze(-1) + mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2) + mask = mask0 * mask1 + similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min) + + # Compute matchability of descriptors + matchability = self.matchability(descriptors) + matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1) + matchability_0 = matchability[:, 0] + matchability_1 = matchability[:, 1] + + # Compute scores from similarity and matchability + scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1) + return scores + + def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor: + """Get matchability of descriptors as a probability""" + matchability = self.matchability(descriptors) + matchability = nn.functional.sigmoid(matchability).squeeze(-1) + return matchability + + +class LightGlueTokenConfidenceLayer(nn.Module): + def __init__(self, config: LightGlueConfig): + super().__init__() + + self.token = nn.Linear(config.descriptor_dim, 1) + + def forward(self, descriptors: torch.Tensor) -> torch.Tensor: + token = self.token(descriptors.detach()) + token = nn.functional.sigmoid(token).squeeze(-1) + return token + + +@auto_docstring +class LightGluePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = LightGlueConfig + base_model_prefix = "lightglue" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: + """obtain matches from a score matrix [Bx M+1 x N+1]""" + batch_size, _, _ = scores.shape + # For each keypoint, get the best match + max0 = scores[:, :-1, :-1].max(2) + max1 = scores[:, :-1, :-1].max(1) + matches0 = max0.indices + matches1 = max1.indices + + # Mutual check for matches + indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None] + indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None] + mutual0 = indices0 == matches1.gather(1, matches0) + mutual1 = indices1 == matches0.gather(1, matches1) + + # Get matching scores and filter based on mutual check and thresholding + max0 = max0.values.exp() + zero = max0.new_tensor(0) + matching_scores0 = torch.where(mutual0, max0, zero) + matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero) + valid0 = mutual0 & (matching_scores0 > threshold) + valid1 = mutual1 & valid0.gather(1, matches1) + + # Filter matches based on mutual check and thresholding of scores + matches0 = torch.where(valid0, matches0, -1) + matches1 = torch.where(valid1, matches1, -1) + matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1) + matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1) + + return matches, matching_scores + + 
+def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + Normalize keypoint locations based on the image shape. + + Args: + keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`): + Keypoint locations in (x, y) format. + height (`int`): + Image height. + width (`int`): + Image width. + + Returns: + `torch.Tensor` of shape `(batch_size, num_keypoints, 2)`: Normalized keypoint locations. + """ + size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None] + shift = size / 2 + scale = size.max(-1).values / 2 + keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None] + return keypoints + + +@auto_docstring( + custom_intro=""" + LightGlue model taking images as inputs and outputting the matches between them. + """ +) +class LightGlueForKeypointMatching(LightGluePreTrainedModel): + """ + LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as + SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient. + It consists of: + 1. Keypoint Encoder + 2. A Graph Neural Network with self and cross attention layers + 3. Matching Assignment layers + + The correspondence ids use -1 to indicate non-matching points. + + Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed. + In ICCV 2023. 
https://arxiv.org/pdf/2306.13643.pdf + """ + + def __init__(self, config: LightGlueConfig): + super().__init__(config) + + self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config) + + self.descriptor_dim = config.descriptor_dim + self.num_layers = config.num_hidden_layers + self.filter_threshold = config.filter_threshold + self.depth_confidence = config.depth_confidence + self.width_confidence = config.width_confidence + + if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim: + self.input_projection = nn.Linear( + config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True + ) + else: + self.input_projection = nn.Identity() + + self.positional_encoder = LightGluePositionalEncoder(config) + + self.transformer_layers = nn.ModuleList( + [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)] + ) + self.match_assignment_layers = nn.ModuleList( + [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.token_confidence = nn.ModuleList( + [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)] + ) + + self.post_init() + + def _get_confidence_threshold(self, layer_index: int) -> float: + """scaled confidence threshold for a given layer""" + threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers) + return np.clip(threshold, 0, 1) + + def _keypoint_processing( + self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + descriptors = descriptors.detach().contiguous() + projected_descriptors = self.input_projection(descriptors) + keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states) + return projected_descriptors, keypoint_encoding_output + + def _get_early_stopped_image_pairs( + self, keypoint_confidences: 
torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor + ) -> torch.Tensor: + """evaluate whether we should stop inference based on the confidence of the keypoints""" + batch_size, _ = mask.shape + if layer_index < self.num_layers - 1: + # If the current layer is not the last layer, we compute the confidence of the keypoints and check + # if we should stop the forward pass through the transformer layers for each pair of images. + keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1) + keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1) + threshold = self._get_confidence_threshold(layer_index) + ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points + early_stopped_pairs = ratio_confident > self.depth_confidence + else: + # If the current layer is the last layer, we stop the forward pass through the transformer layers for + # all pairs of images. + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + return early_stopped_pairs + + def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None): + if early_stops is not None: + descriptors = descriptors[early_stops] + mask = mask[early_stops] + scores = self.match_assignment_layers[layer_index](descriptors, mask) + matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold) + return matches, matching_scores + + def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor: + """mask points which should be removed""" + keep = scores > (1 - self.width_confidence) + if confidences is not None: # Low-confidence points are never pruned. 
+ keep |= confidences <= self._get_confidence_threshold(layer_index) + return keep + + def _do_layer_keypoint_pruning( + self, + descriptors: torch.Tensor, + keypoints: torch.Tensor, + mask: torch.Tensor, + indices: torch.Tensor, + prune_output: torch.Tensor, + keypoint_confidences: torch.Tensor, + layer_index: int, + ): + """ + For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the + descriptors. + """ + batch_size, _, _ = descriptors.shape + descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors) + pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index) + pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False)) + + # For each image, we extract the pruned indices and the corresponding descriptors and keypoints. + pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = ( + [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)] + for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices] + ) + for i in range(batch_size): + prune_output[i, pruned_indices[i]] += 1 + + # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch. 
+ pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = ( + pad_sequence(pruned_tensor, batch_first=True) + for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask] + ) + pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1) + pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1) + + return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output + + def _concat_early_stopped_outputs( + self, + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ): + early_stops_indices = torch.stack(early_stops_indices) + matches, final_pruned_keypoints_indices = ( + pad_sequence(tensor, batch_first=True, padding_value=-1) + for tensor in [matches, final_pruned_keypoints_indices] + ) + matching_scores, final_pruned_keypoints_iterations = ( + pad_sequence(tensor, batch_first=True, padding_value=0) + for tensor in [matching_scores, final_pruned_keypoints_iterations] + ) + matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = ( + tensor[early_stops_indices] + for tensor in [ + matches, + matching_scores, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + ] + ) + return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores + + def _do_final_keypoint_pruning( + self, + indices: torch.Tensor, + matches: torch.Tensor, + matching_scores: torch.Tensor, + num_keypoints: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to + # have tensors from + batch_size, _ = indices.shape + indices, matches, matching_scores = ( + tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores] + ) + indices0 = indices[:, 0] + indices1 = indices[:, 1] + matches0 = matches[:, 0] + 
matches1 = matches[:, 1] + matching_scores0 = matching_scores[:, 0] + matching_scores1 = matching_scores[:, 1] + + # Prepare final matches and matching scores + _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype) + _matching_scores = torch.zeros( + (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype + ) + # Fill the matches and matching scores for each image pair + for i in range(batch_size // 2): + _matches[i, 0, indices0[i]] = torch.where( + matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0)) + ) + _matches[i, 1, indices1[i]] = torch.where( + matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0)) + ) + _matching_scores[i, 0, indices0[i]] = matching_scores0[i] + _matching_scores[i, 1, indices1[i]] = matching_scores1[i] + return _matches, _matching_scores + + def _match_image_pair( + self, + keypoints: torch.Tensor, + descriptors: torch.Tensor, + height: int, + width: int, + mask: torch.Tensor = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if keypoints.shape[2] == 0: # no keypoints + shape = keypoints.shape[:-1] + return ( + keypoints.new_full(shape, -1, dtype=torch.int), + keypoints.new_zeros(shape), + keypoints.new_zeros(shape), + all_hidden_states, + all_attentions, + ) + + device = keypoints.device + batch_size, _, initial_num_keypoints, _ = keypoints.shape + num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1) + # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2) + keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2) + mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None + descriptors = 
descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim) + image_indices = torch.arange(batch_size * 2, device=device) + # Keypoint normalization + keypoints = normalize_keypoints(keypoints, height, width) + + descriptors, keypoint_encoding_output = self._keypoint_processing( + descriptors, keypoints, output_hidden_states=output_hidden_states + ) + + keypoints = keypoint_encoding_output[0] + + # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the + # keypoints is above a certain threshold. + do_early_stop = self.depth_confidence > 0 + # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of + # the keypoints is below a certain threshold. + do_keypoint_pruning = self.width_confidence > 0 + + early_stops_indices = [] + matches = [] + matching_scores = [] + final_pruned_keypoints_indices = [] + final_pruned_keypoints_iterations = [] + + pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1) + pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices) + + for layer_index in range(self.num_layers): + input_shape = descriptors.size() + if mask is not None: + extended_attention_mask = self.get_extended_attention_mask(mask, input_shape) + else: + extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device) + layer_output = self.transformer_layers[layer_index]( + descriptors, + keypoints, + attention_mask=extended_attention_mask, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + ) + descriptors, hidden_states, attention = layer_output + if output_hidden_states: + all_hidden_states = all_hidden_states + hidden_states + if output_attentions: + all_attentions = all_attentions + attention + + if do_early_stop: + if layer_index < self.num_layers - 1: + # Get the confidence of the keypoints for the current layer + 
keypoint_confidences = self.token_confidence[layer_index](descriptors) + + # Determine which pairs of images should be early stopped based on the confidence of the keypoints for + # the current layer. + early_stopped_pairs = self._get_early_stopped_image_pairs( + keypoint_confidences, layer_index, mask, num_points=num_points_per_pair + ) + else: + # Early stopping always occurs at the last layer + early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool) + + if torch.any(early_stopped_pairs): + # If a pair of images is considered early stopped, we compute the matches for the remaining + # keypoints and stop the forward pass through the transformer layers for this pair of images. + early_stops = early_stopped_pairs.repeat_interleave(2) + early_stopped_image_indices = image_indices[early_stops] + early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching( + descriptors, mask, layer_index, early_stops=early_stops + ) + early_stops_indices.extend(list(early_stopped_image_indices)) + matches.extend(list(early_stopped_matches)) + matching_scores.extend(list(early_stopped_matching_scores)) + if do_keypoint_pruning: + final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops])) + final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops])) + + # Remove image pairs that have been early stopped from the forward pass + num_points_per_pair = num_points_per_pair[~early_stopped_pairs] + descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple( + ( + tensor[~early_stops] + for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices] + ) + ) + keypoints = (keypoints_0, keypoint_1) + if do_keypoint_pruning: + pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple( + ( + tensor[~early_stops] + for tensor in [ + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + ] + ) + ) + # If all pairs of images are early stopped, we 
stop the forward pass through the transformer + # layers for all pairs of images. + if torch.all(early_stopped_pairs): + break + + if do_keypoint_pruning: + # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of + # the keypoints is below a certain threshold. + descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = ( + self._do_layer_keypoint_pruning( + descriptors, + keypoints, + mask, + pruned_keypoints_indices, + pruned_keypoints_iterations, + keypoint_confidences, + layer_index, + ) + ) + + if do_early_stop and do_keypoint_pruning: + # Concatenate early stopped outputs together and perform final keypoint pruning + final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = ( + self._concat_early_stopped_outputs( + early_stops_indices, + final_pruned_keypoints_indices, + final_pruned_keypoints_iterations, + matches, + matching_scores, + ) + ) + matches, matching_scores = self._do_final_keypoint_pruning( + final_pruned_keypoints_indices, + matches, + matching_scores, + initial_num_keypoints, + ) + else: + matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1) + final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers + + final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape( + batch_size, 2, initial_num_keypoints + ) + + return ( + matches, + matching_scores, + final_pruned_keypoints_iterations, + all_hidden_states, + all_attentions, + ) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> Union[Tuple, LightGlueKeypointMatchingOutput]: + loss = None + if labels is not None: + raise ValueError("LightGlue is not trainable, no labels should be provided.") + + output_attentions 
= output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + if pixel_values.ndim != 5 or pixel_values.size(1) != 2: + raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)") + + batch_size, _, channels, height, width = pixel_values.shape + pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width) + keypoint_detections = self.keypoint_detector(pixel_values) + + keypoints, _, descriptors, mask = keypoint_detections[:4] + keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values) + descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values) + mask = mask.reshape(batch_size, 2, -1) + + absolute_keypoints = keypoints.clone() + absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width + absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height + + matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair( + absolute_keypoints, + descriptors, + height, + width, + mask=mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + return LightGlueKeypointMatchingOutput( + loss=loss, + matches=matches, + matching_scores=matching_scores, + keypoints=keypoints, + prune=prune, + mask=mask, + hidden_states=hidden_states, + attentions=attentions, + ) + + +__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"] diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py index c2e1f93626..e39c4f933b 100644 --- a/src/transformers/models/superglue/image_processing_superglue.py +++ b/src/transformers/models/superglue/image_processing_superglue.py @@ -17,7 +17,6 @@ from typing import 
TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np -from ... import is_torch_available, is_vision_available from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict from ...image_transforms import resize, to_channel_dimension_format from ...image_utils import ( @@ -29,7 +28,9 @@ from ...image_utils import ( infer_channel_dimension_format, is_pil_image, is_scaled_image, + is_torch_available, is_valid_image, + is_vision_available, to_numpy_array, valid_images, validate_preprocess_arguments, diff --git a/src/transformers/models/superpoint/modeling_superpoint.py b/src/transformers/models/superpoint/modeling_superpoint.py index a0077b3e04..63d717add7 100644 --- a/src/transformers/models/superpoint/modeling_superpoint.py +++ b/src/transformers/models/superpoint/modeling_superpoint.py @@ -253,7 +253,7 @@ class SuperPointInterestPointDecoder(nn.Module): keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints) # Convert (y, x) to (x, y) - keypoints = torch.flip(keypoints, [1]).float() + keypoints = torch.flip(keypoints, [1]).to(scores.dtype) return keypoints, scores diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 564213d720..386a85228f 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -179,6 +179,7 @@ from .import_utils import ( is_librosa_available, is_liger_kernel_available, is_lomo_available, + is_matplotlib_available, is_mlx_available, is_natten_available, is_ninja_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ebd1ae9ef1..420955e998 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -225,6 +225,7 @@ _triton_available = _is_package_available("triton") _spqr_available = _is_package_available("spqr_quant") _rich_available = _is_package_available("rich") _kernels_available = _is_package_available("kernels") 
+_matplotlib_available = _is_package_available("matplotlib") _torch_version = "N/A" _torch_available = False @@ -1443,6 +1444,10 @@ def is_rich_available(): return _rich_available +def is_matplotlib_available(): + return _matplotlib_available + + def check_torch_load_is_safe(): if not is_torch_greater_or_equal("2.6"): raise ValueError( diff --git a/tests/models/lightglue/__init__.py b/tests/models/lightglue/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/lightglue/test_image_processing_lightglue.py b/tests/models/lightglue/test_image_processing_lightglue.py new file mode 100644 index 0000000000..3d01acf469 --- /dev/null +++ b/tests/models/lightglue/test_image_processing_lightglue.py @@ -0,0 +1,96 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +from tests.models.superglue.test_image_processing_superglue import ( + SuperGlueImageProcessingTest, + SuperGlueImageProcessingTester, +) +from transformers.testing_utils import require_torch, require_vision +from transformers.utils import is_torch_available, is_vision_available + + +if is_torch_available(): + import numpy as np + import torch + + from transformers.models.lightglue.modeling_lightglue import LightGlueKeypointMatchingOutput + +if is_vision_available(): + from transformers import LightGlueImageProcessor + + +def random_array(size): + return np.random.randint(255, size=size) + + +def random_tensor(size): + return torch.rand(size) + + +class LightGlueImageProcessingTester(SuperGlueImageProcessingTester): + """Tester for LightGlueImageProcessor""" + + def __init__( + self, + parent, + batch_size=6, + num_channels=3, + image_size=18, + min_resolution=30, + max_resolution=400, + do_resize=True, + size=None, + do_grayscale=True, + ): + super().__init__( + parent, batch_size, num_channels, image_size, min_resolution, max_resolution, do_resize, size, do_grayscale + ) + + def prepare_keypoint_matching_output(self, pixel_values): + """Prepare a fake output for the keypoint matching model with random matches between 50 keypoints per image.""" + max_number_keypoints = 50 + batch_size = len(pixel_values) + mask = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int) + keypoints = torch.zeros((batch_size, 2, max_number_keypoints, 2)) + matches = torch.full((batch_size, 2, max_number_keypoints), -1, dtype=torch.int) + scores = torch.zeros((batch_size, 2, max_number_keypoints)) + prune = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int) + for i in range(batch_size): + random_number_keypoints0 = np.random.randint(10, max_number_keypoints) + random_number_keypoints1 = np.random.randint(10, max_number_keypoints) + random_number_matches = np.random.randint(5, min(random_number_keypoints0, random_number_keypoints1)) + 
mask[i, 0, :random_number_keypoints0] = 1 + mask[i, 1, :random_number_keypoints1] = 1 + keypoints[i, 0, :random_number_keypoints0] = torch.rand((random_number_keypoints0, 2)) + keypoints[i, 1, :random_number_keypoints1] = torch.rand((random_number_keypoints1, 2)) + random_matches_indices0 = torch.randperm(random_number_keypoints1, dtype=torch.int)[:random_number_matches] + random_matches_indices1 = torch.randperm(random_number_keypoints0, dtype=torch.int)[:random_number_matches] + matches[i, 0, random_matches_indices1] = random_matches_indices0 + matches[i, 1, random_matches_indices0] = random_matches_indices1 + scores[i, 0, random_matches_indices1] = torch.rand((random_number_matches,)) + scores[i, 1, random_matches_indices0] = torch.rand((random_number_matches,)) + return LightGlueKeypointMatchingOutput( + mask=mask, keypoints=keypoints, matches=matches, matching_scores=scores, prune=prune + ) + + +@require_torch +@require_vision +class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase): + image_processing_class = LightGlueImageProcessor if is_vision_available() else None + + def setUp(self) -> None: + super().setUp() + self.image_processor_tester = LightGlueImageProcessingTester(self) diff --git a/tests/models/lightglue/test_modeling_lightglue.py b/tests/models/lightglue/test_modeling_lightglue.py new file mode 100644 index 0000000000..20d9f2ef61 --- /dev/null +++ b/tests/models/lightglue/test_modeling_lightglue.py @@ -0,0 +1,584 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import unittest + +from datasets import load_dataset + +from transformers.models.lightglue.configuration_lightglue import LightGlueConfig +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import cached_property, is_torch_available, is_vision_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor + + +if is_torch_available(): + import torch + + from transformers import LightGlueForKeypointMatching + +if is_vision_available(): + from transformers import AutoImageProcessor + + +class LightGlueModelTester: + def __init__( + self, + parent, + batch_size=2, + image_width=80, + image_height=60, + keypoint_detector_config={ + "encoder_hidden_sizes": [32, 32, 64], + "decoder_hidden_size": 64, + "keypoint_decoder_dim": 65, + "descriptor_decoder_dim": 64, + "keypoint_threshold": 0.005, + "max_keypoints": 256, + "nms_radius": 4, + "border_removal_distance": 4, + }, + descriptor_dim: int = 64, + num_layers: int = 2, + num_heads: int = 4, + depth_confidence: float = 1.0, + width_confidence: float = 1.0, + filter_threshold: float = 0.1, + matching_threshold: float = 0.0, + ): + self.parent = parent + self.batch_size = batch_size + self.image_width = image_width + self.image_height = image_height + + self.keypoint_detector_config = keypoint_detector_config + self.descriptor_dim = descriptor_dim + self.num_layers = num_layers + self.num_heads = num_heads + self.depth_confidence = depth_confidence + self.width_confidence = width_confidence + self.filter_threshold = filter_threshold + self.matching_threshold = matching_threshold + + def prepare_config_and_inputs(self): + # LightGlue expects a grayscale image as input + pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width]) + config = 
self.get_config() + return config, pixel_values + + def get_config(self): + return LightGlueConfig( + keypoint_detector_config=self.keypoint_detector_config, + descriptor_dim=self.descriptor_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + depth_confidence=self.depth_confidence, + width_confidence=self.width_confidence, + filter_threshold=self.filter_threshold, + matching_threshold=self.matching_threshold, + attn_implementation="eager", + ) + + def create_and_check_model(self, config, pixel_values): + model = LightGlueForKeypointMatching(config=config) + model.to(torch_device) + model.eval() + result = model(pixel_values) + maximum_num_matches = result.mask.shape[-1] + self.parent.assertEqual( + result.keypoints.shape, + (self.batch_size, 2, maximum_num_matches, 2), + ) + self.parent.assertEqual( + result.matches.shape, + (self.batch_size, 2, maximum_num_matches), + ) + self.parent.assertEqual( + result.matching_scores.shape, + (self.batch_size, 2, maximum_num_matches), + ) + self.parent.assertEqual( + result.prune.shape, + (self.batch_size, 2, maximum_num_matches), + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + inputs_dict = {"pixel_values": pixel_values} + return config, inputs_dict + + +@require_torch +class LightGlueModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (LightGlueForKeypointMatching,) if is_torch_available() else () + all_generative_model_classes = () if is_torch_available() else () + + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + has_attentions = True + + def setUp(self): + self.model_tester = LightGlueModelTester(self) + self.config_tester = ConfigTester(self, config_class=LightGlueConfig, has_text_modality=False, hidden_size=37) + + def test_config(self): + self.config_tester.create_and_test_config_to_json_string() + 
self.config_tester.create_and_test_config_to_json_file() + self.config_tester.create_and_test_config_from_and_save_pretrained() + self.config_tester.create_and_test_config_with_num_labels() + self.config_tester.check_config_can_be_init_without_params() + self.config_tester.check_config_arguments_init() + + @unittest.skip(reason="LightGlueForKeypointMatching does not use inputs_embeds") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching does not support input and output embeddings") + def test_model_get_set_embeddings(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching does not use feedforward chunking") + def test_feed_forward_chunking(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip(reason="LightGlueForKeypointMatching is not trainable") + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + @unittest.skip(reason="LightGlue does not output any loss term in the forward pass") + def test_retain_grad_hidden_states_attentions(self): + pass + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["pixel_values"] + self.assertListEqual(arg_names[:1], 
expected_arg_names) + + def test_hidden_states_output(self): + def check_hidden_states_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + hidden_states = outputs.hidden_states + maximum_num_matches = outputs.mask.shape[-1] + + hidden_states_sizes = [ + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim * 2, + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim, + self.model_tester.descriptor_dim * 2, + self.model_tester.descriptor_dim, + ] * self.model_tester.num_layers + + for i, hidden_states_size in enumerate(hidden_states_sizes): + self.assertListEqual( + list(hidden_states[i].shape[-2:]), + [maximum_num_matches, hidden_states_size], + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + inputs_dict["output_hidden_states"] = True + check_hidden_states_output(inputs_dict, config, model_class) + + # check that output_hidden_states also work using config + del inputs_dict["output_hidden_states"] + config.output_hidden_states = True + + check_hidden_states_output(inputs_dict, config, model_class) + + def test_attention_outputs(self): + def check_attention_output(inputs_dict, config, model_class): + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + attentions = outputs.attentions + maximum_num_matches = outputs.mask.shape[-1] + + expected_attention_shape = [self.model_tester.num_heads, maximum_num_matches, maximum_num_matches] + + for i, attention in enumerate(attentions): + self.assertListEqual( + list(attention.shape[-3:]), + expected_attention_shape, + ) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + 
+ + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + check_attention_output(inputs_dict, config, model_class) + + # check that output_attentions also works using config + del inputs_dict["output_attentions"] + config.output_attentions = True + + check_attention_output(inputs_dict, config, model_class) + + @slow + def test_model_from_pretrained(self): + from_pretrained_ids = ["ETH-CVG/lightglue_superpoint"] + for model_name in from_pretrained_ids: + model = LightGlueForKeypointMatching.from_pretrained(model_name) + self.assertIsNotNone(model) + + # Copied from tests.models.superglue.test_modeling_superglue.SuperGlueModelTest.test_forward_labels_should_be_none + def test_forward_labels_should_be_none(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + model_inputs = self._prepare_for_class(inputs_dict, model_class) + # Provide an arbitrary sized Tensor as labels to model inputs + model_inputs["labels"] = torch.rand((128, 128)) + + with self.assertRaises(ValueError) as cm: + model(**model_inputs) + self.assertEqual(ValueError, cm.exception.__class__) + + +def prepare_imgs(): + dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train") + image0 = dataset[0]["image"] + image1 = dataset[1]["image"] + image2 = dataset[2]["image"] + # [image1, image1] on purpose to test the model early stopping + return [[image2, image0], [image1, image1]] + + +@require_torch +@require_vision +class LightGlueModelIntegrationTest(unittest.TestCase): + @cached_property + def default_image_processor(self): + return AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint") if is_vision_available() else None + + @slow + def test_inference(self): + model = LightGlueForKeypointMatching.from_pretrained( + "ETH-CVG/lightglue_superpoint",
attn_implementation="eager" + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values0 = outputs.matches[0, 0, 10:30] + predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30] + + predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item() + predicted_matches_values1 = outputs.matches[1, 0, 10:30] + predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30] + + expected_number_of_matches0 = 140 + expected_matches_values0 = torch.tensor( + [14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11], + dtype=torch.int64, + device=torch_device, + ) + expected_matching_scores_values0 = torch.tensor( + [0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583], + device=torch_device, + ) + + expected_number_of_matches1 = 866 + expected_matches_values1 = torch.tensor( + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], + dtype=torch.int64, + device=torch_device, + ) + expected_matching_scores_values1 = torch.tensor( + [ + 0.6188,0.7817,0.5686,0.9353,0.9801,0.9193,0.8632,0.9111,0.9821,0.5496, + 0.9906,0.8682,0.9679,0.9914,0.9318,0.1910,0.9669,0.3240,0.9971,0.9923, + ], + device=torch_device + ) # fmt:skip + + # expected_early_stopping_layer = 2 + # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item() + # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer) + # self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches) + + """ + Because of inconsistencies introduced between CUDA versions, the checks here are less strict. 
LightGlue relies + on SuperPoint, which may, depending on CUDA version, return a different number of keypoints (866 or 867 in this + specific test example). The consequence of having a different number of keypoints is that the number of matches + will also be different. In the first 20 matches being checked, having one keypoint less will result in 1 less + match. The matching scores will also be different, as the keypoints are different. The checks here are less + strict to account for these inconsistencies. + Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the + expected values, individually. Here, fewer than 4 values are allowed to differ in each check. + + This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787) + Such CUDA inconsistencies can be found + [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300) + """ + + self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4) + self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2)) + < 4 + ) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2)) + < 4 + ) + self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4) + self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4) + + @slow + def test_inference_without_early_stop(self): + model = LightGlueForKeypointMatching.from_pretrained( + "ETH-CVG/lightglue_superpoint", attn_implementation="eager", depth_confidence=1.0 + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs =
model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values0 = outputs.matches[0, 0, 10:30] + predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30] + + predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item() + predicted_matches_values1 = outputs.matches[1, 0, 10:30] + predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30] + + expected_number_of_matches0 = 134 + expected_matches_values0 = torch.tensor( + [-1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64 + ).to(torch_device) + expected_matching_scores_values0 = torch.tensor( + [0.0083, 0, 0.2022, 0.0621, 0, 0.0828, 0, 0, 0.0003, 0, 0, 0, 0.0960, 0, 0, 0.6940, 0, 0.7167, 0, 0.1512] + ).to(torch_device) + + expected_number_of_matches1 = 862 + expected_matches_values1 = torch.tensor( + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=torch.int64 + ).to(torch_device) + expected_matching_scores_values1 = torch.tensor( + [ + 0.4772, + 0.3781, + 0.0631, + 0.9559, + 0.8746, + 0.9271, + 0.4882, + 0.5406, + 0.9439, + 0.1526, + 0.5028, + 0.4107, + 0.5591, + 0.9130, + 0.7572, + 0.0302, + 0.4532, + 0.0893, + 0.9490, + 0.4880, + ] + ).to(torch_device) + + # expected_early_stopping_layer = 2 + # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item() + # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer) + # self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches) + + """ + Because of inconsistencies introduced between CUDA versions, the checks here are less strict. LightGlue relies + on SuperPoint, which may, depending on CUDA version, return a different number of keypoints (866 or 867 in this + specific test example).
The consequence of having a different number of keypoints is that the number of matches + will also be different. In the first 20 matches being checked, having one keypoint less will result in 1 less + match. The matching scores will also be different, as the keypoints are different. The checks here are less + strict to account for these inconsistencies. + Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the + expected values, individually. Here, fewer than 4 values are allowed to differ in each check. + + This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787) + Such CUDA inconsistencies can be found + [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300) + """ + + self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4) + self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2)) + < 4 + ) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2)) + < 4 + ) + self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4) + self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4) + + @slow + def test_inference_without_early_stop_and_keypoint_pruning(self): + model = LightGlueForKeypointMatching.from_pretrained( + "ETH-CVG/lightglue_superpoint", + attn_implementation="eager", + depth_confidence=1.0, + width_confidence=1.0, + ).to(torch_device) + preprocessor = self.default_image_processor + images = prepare_imgs() + inputs = preprocessor(images=images, return_tensors="pt").to(torch_device) + with torch.no_grad(): + outputs = model(**inputs, output_hidden_states=True, output_attentions=True) + + predicted_number_of_matches0 =
torch.sum(outputs.matches[0][0] != -1).item() + predicted_matches_values0 = outputs.matches[0, 0, 10:30] + predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30] + + predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item() + predicted_matches_values1 = outputs.matches[1, 0, 10:30] + predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30] + + expected_number_of_matches0 = 144 + expected_matches_values0 = torch.tensor( + [-1, -1, 17, -1, -1, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64 + ).to(torch_device) + expected_matching_scores_values0 = torch.tensor( + [ + 0.0699, + 0.0302, + 0.3356, + 0.0820, + 0, + 0.2266, + 0, + 0, + 0.0241, + 0, + 0, + 0, + 0.1674, + 0, + 0, + 0.8114, + 0, + 0.8120, + 0, + 0.2936, + ] + ).to(torch_device) + + expected_number_of_matches1 = 862 + expected_matches_values1 = torch.tensor( + [10, 11, -1, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, -1, 26, -1, 28, 29], dtype=torch.int64 + ).to(torch_device) + expected_matching_scores_values1 = torch.tensor( + [ + 0.4772, + 0.3781, + 0.0631, + 0.9559, + 0.8746, + 0.9271, + 0.4882, + 0.5406, + 0.9439, + 0.1526, + 0.5028, + 0.4107, + 0.5591, + 0.9130, + 0.7572, + 0.0302, + 0.4532, + 0.0893, + 0.9490, + 0.4880, + ] + ).to(torch_device) + + # expected_early_stopping_layer = 2 + # predicted_early_stopping_layer = torch.max(outputs.prune[1]).item() + # self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer) + # self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches) + + """ + Because of inconsistencies introduced between CUDA versions, the checks here are less strict. LightGlue relies + on SuperPoint, which may, depending on CUDA version, return a different number of keypoints (866 or 867 in this + specific test example). The consequence of having a different number of keypoints is that the number of matches + will also be different.
In the first 20 matches being checked, having one keypoint less will result in 1 less + match. The matching scores will also be different, as the keypoints are different. The checks here are less + strict to account for these inconsistencies. + Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the + expected values, individually. Here, fewer than 4 values are allowed to differ in each check. + + This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787) + Such CUDA inconsistencies can be found + [here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300) + """ + + self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4) + self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2)) + < 4 + ) + self.assertTrue( + torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2)) + < 4 + ) + self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4) + self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
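
Reviewer note, not part of the patch: the relaxed checks in the three integration tests all follow one pattern, which is to count the elements that disagree and require that count to stay below 4. A minimal sketch of that pattern in plain Python follows, so it runs without PyTorch. The helper names `count_mismatches` and `within_tolerance` are hypothetical and do not exist in the test suite; the closeness rule is assumed to mirror the documented `torch.isclose` criterion, `|a - b| <= atol + rtol * |b|`, with its default `rtol` of 1e-5.

```python
from typing import Sequence


def count_mismatches(
    predicted: Sequence[float],
    expected: Sequence[float],
    atol: float = 1e-2,
    rtol: float = 1e-5,
) -> int:
    """Count positions where `predicted` is not close to `expected`.

    Hypothetical helper: closeness is |p - e| <= atol + rtol * |e|,
    the same rule that torch.isclose documents.
    """
    if len(predicted) != len(expected):
        raise ValueError("predicted and expected must have the same length")
    return sum(1 for p, e in zip(predicted, expected) if abs(p - e) > atol + rtol * abs(e))


def within_tolerance(
    predicted: Sequence[float],
    expected: Sequence[float],
    max_mismatches: int = 4,
    atol: float = 1e-2,
    rtol: float = 1e-5,
) -> bool:
    """Return True when strictly fewer than `max_mismatches` positions differ."""
    return count_mismatches(predicted, expected, atol=atol, rtol=rtol) < max_mismatches


if __name__ == "__main__":
    # Illustrative values only, not taken from the model outputs.
    scores_pred = [0.38, 0.00, 0.44, 0.24]
    scores_exp = [0.3796, 0.0, 0.4439, 0.30]
    print(count_mismatches(scores_pred, scores_exp))  # one score differs by more than atol
    # Exact comparison for integer match indices: set both tolerances to zero.
    print(within_tolerance([14, -1, 15], [14, -1, 17], atol=0.0, rtol=0.0))
```

Setting both tolerances to zero reduces the check to the exact `!=` comparison the tests apply to match indices, so one helper covers both the score and the index assertions.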