Add LightGlue model (#31718)
* init * chore: various changes to LightGlue * chore: various changes to LightGlue * chore: various changes to LightGlue * chore: various changes to LightGlue * Fixed dynamo bug and image padding tests * refactor: applied refactoring changes from SuperGlue's concat, batch and stack functions to LightGlue file * tests: removed sdpa support and changed expected values * chore: added some docs and refactoring * chore: fixed copy to superpoint.image_processing_superpoint.convert_to_grayscale * feat: adding batch implementation * feat: added validation for preprocess and post process method to LightGlueImageProcessor * chore: changed convert_lightglue_to_hf script to comply with new standard * chore: changed lightglue test values to match new lightglue config pushed to hub * chore: simplified convert_lightglue_to_hf conversion map * feat: adding batching implementation * chore: make style * feat: added threshold to post_process_keypoint_matching method * fix: added missing instructions that turns keypoints back to absolute coordinate before matching forward * fix: added typehint and docs * chore: make style * [run-slow] lightglue * fix: add matches different from -1 to compute valid matches in post_process_keypoint_matching * tests: added CUDA proof tests similar to SuperGlue * chore: various changes to modeling_lightglue.py - Added "Copies from" statements for copied functions from modeling_superglue.py - Added missing docstrings - Removed unused functions or classes - Removed unnecessary statements - Added missing typehints - Added comments to the main forward method * chore: various changes to convert_lightglue_to_hf.py - Added model saving - Added model reloading * chore: fixed imports in lightglue files * [run-slow] lightglue * chore: make style * [run-slow] lightglue * Apply suggestions from code review Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * [run-slow] lightglue * chore: Applied some suggestions from review - Added missing typehints - Refactor "cuda" to device variable - Variable renaming - LightGlue output order changed - Make style * fix: added missing grayscale argument in image processor in case use of SuperPoint keypoint detector * fix: changed lightglue HF repo to lightglue_superpoint with grayscale default to True * refactor: make keypoints `(batch_size, num_keypoints, keypoint_dim)` through forward and unsqueeze only before attention layer * refactor: refactor do_layer_keypoint_pruning * tests: added tests with no early stop and keypoint pruning * refactor: various refactoring to modeling_lightglue.py - Removed unused functions - Renamed variables for consistency - Added comments for clarity - Set methods to private in LightGlueForKeypointMatching - Replaced tensor initialization to list then concatenation - Used more pythonic list comprehension for repetitive instructions * refactor: added comments and renamed filter_matches to get_matches_from_scores * tests: added copied from statement with superglue tests * docs: added comment to prepare_keypoint_matching_output function in tests * [run-slow] lightglue * refactor: reordered _concat_early_stopped_outputs in LightGlue class * [run-slow] lightglue * docs: added lightglue.md model doc * docs: added Optional typehint to LightGlueKeypointMatchingOutput * chore: removed pad_images function * chore: set do_grayscale default value to True in LightGlueImageProcessor * Apply suggestions from code review Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * Apply suggestions from code review Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com> * docs: added missing LightGlueConfig typehint in nn.Module __init__ methods * docs: removed unnecessary code in docs * docs: import SuperPointConfig only from a TYPE_CHECKING context * chore: use PretrainedConfig arguments `num_hidden_layers` and `num_attention_heads` instead of `num_layers` and `num_heads` * chore: added organization as arg in convert_lightglue_to_hf.py script * refactor: set device variable * chore: added "gelu" in LightGlueConfig as hidden_act parameter * docs: added comments to reshape.flip.reshape instruction to perform cross attention * refactor: used batched inference for keypoint detector forward pass * fix: added fix for SDPA tests * docs: fixed docstring for LightGlueImageProcessor * [run-slow] lightglue * refactor: removed unused line * refactor: added missing arguments in LightGlueConfig init method * docs: added missing LightGlueConfig typehint in init methods * refactor: added checkpoint url as default variable to verify models output only if it is the default url * fix: moved print message inside if statement * fix: added log assignment r removal in convert script * fix: got rid of confidence_thresholds as registered buffers * refactor: applied suggestions from SuperGlue PR * docs: changed copyright to 2025 * refactor: modular LightGlue * fix: removed unnecessary import * feat: added plot_keypoint_matching method to LightGlueImageProcessor with matplotlib soft dependency * fix: added missing import error for matplotlib * Updated convert script to push on ETH org * fix: added missing licence * fix: make fix-copies * refactor: use cohere apply_rotary_pos_emb function * fix: update model references to use ETH-CVG/lightglue_superpoint * refactor: add and use intermediate_size attribute in config to inherit CLIPMLP for LightGlueMLP * refactor: explicit variables instead of slicing * refactor: use can_return_tuple decorator in LightGlue model * fix: make fix-copies * docs: Update model references in `lightglue.md` to use the correct pretrained model from ETH-CVG * Refactor LightGlue configuration and processing classes - Updated type hints for `keypoint_detector_config` in `LightGlueConfig` to use `SuperPointConfig` directly. - Changed `size` parameter in `LightGlueImageProcessor` to be optional. - Modified `position_embeddings` in `LightGlueAttention` and `LightGlueAttentionBlock` to be optional tuples. - Cleaned up import statements across multiple files for better readability and consistency. * refactor: Update LightGlue configuration to enforce eager attention implementation - Added `attn_implementation="eager"` to `keypoint_detector_config` in `LightGlueConfig` and `LightGlueAttention` classes. - Removed unnecessary logging related to attention implementation fallback. - Cleaned up import statements for better readability. * refactor: renamed message into attention_output * fix: ensure device compatibility in LightGlueMatchAssignmentLayer descriptor normalization - Updated the normalization of `m_descriptors` to use the correct device for the tensor, ensuring compatibility across different hardware setups. * refactor: removed Conv layers from init_weights since LightGlue doesn't have any * refactor: replace add_start_docstrings with auto_docstring in LightGlue models - Updated LightGlue model classes to utilize the new auto_docstring utility for automatic documentation generation. - Removed legacy docstring handling to streamline the code and improve maintainability. * refactor: simplify LightGlue image processing tests by inheriting from SuperGlue - Refactored `LightGlueImageProcessingTester` and `LightGlueImageProcessingTest` to inherit from their SuperGlue counterparts, reducing code duplication. - Removed redundant methods and properties, streamlining the test setup and improving maintainability. * test: forced eager attention implementation to LightGlue model tests - Updated `LightGlueModelTester` to include `attn_implementation="eager"` in the model configuration. - This change aligns the test setup with the recent updates in LightGlue configuration for eager attention. * refactor: update LightGlue model references * fix: import error * test: enhance LightGlue image processing tests with setup method - Added a setup method in `LightGlueImageProcessingTest` to initialize `LightGlueImageProcessingTester`. - Included a docstring for `LightGlueImageProcessingTester` to clarify its purpose. * refactor: added LightGlue image processing implementation to modular file * refactor: moved attention blocks into the transformer layer * fix: added missing import * fix: added missing import in __all__ variable * doc: added comment about enforcing eager attention because of SuperPoint * refactor: added SuperPoint eager attention comment and moved functions to the closest they are used --------- Co-authored-by: Steven Bucaille <steven.bucaille@buawei.com> Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
This commit is contained in:
@@ -743,6 +743,8 @@
|
|||||||
title: ImageGPT
|
title: ImageGPT
|
||||||
- local: model_doc/levit
|
- local: model_doc/levit
|
||||||
title: LeViT
|
title: LeViT
|
||||||
|
- local: model_doc/lightglue
|
||||||
|
title: LightGlue
|
||||||
- local: model_doc/mask2former
|
- local: model_doc/mask2former
|
||||||
title: Mask2Former
|
title: Mask2Former
|
||||||
- local: model_doc/maskformer
|
- local: model_doc/maskformer
|
||||||
|
|||||||
104
docs/source/en/model_doc/lightglue.md
Normal file
104
docs/source/en/model_doc/lightglue.md
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
|
||||||
|
Licensed under the MIT License; you may not use this file except in compliance with
|
||||||
|
the License.
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||||
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations under the License.
|
||||||
|
|
||||||
|
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||||
|
rendered properly in your Markdown viewer.
|
||||||
|
|
||||||
|
|
||||||
|
-->
|
||||||
|
|
||||||
|
# LightGlue
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The LightGlue model was proposed in [LightGlue: Local Feature Matching at Light Speed](https://arxiv.org/abs/2306.13643)
|
||||||
|
by Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys.
|
||||||
|
|
||||||
|
Similar to [SuperGlue](https://huggingface.co/magic-leap-community/superglue_outdoor), this model consists of matching
|
||||||
|
two sets of local features extracted from two images, its goal is to be faster than SuperGlue. Paired with the
|
||||||
|
[SuperPoint model](https://huggingface.co/magic-leap-community/superpoint), it can be used to match two images and
|
||||||
|
estimate the pose between them. This model is useful for tasks such as image matching, homography estimation, etc.
|
||||||
|
|
||||||
|
The abstract from the paper is the following:
|
||||||
|
|
||||||
|
*We introduce LightGlue, a deep neural network that learns to match local features across images. We revisit multiple
|
||||||
|
design decisions of SuperGlue, the state of the art in sparse matching, and derive simple but effective improvements.
|
||||||
|
Cumulatively, they make LightGlue more efficient - in terms of both memory and computation, more accurate, and much
|
||||||
|
easier to train. One key property is that LightGlue is adaptive to the difficulty of the problem: the inference is much
|
||||||
|
faster on image pairs that are intuitively easy to match, for example because of a larger visual overlap or limited
|
||||||
|
appearance change. This opens up exciting prospects for deploying deep matchers in latency-sensitive applications like
|
||||||
|
3D reconstruction. The code and trained models are publicly available at this [https URL](https://github.com/cvg/LightGlue)*
|
||||||
|
|
||||||
|
## How to use
|
||||||
|
|
||||||
|
Here is a quick example of using the model. Since this model is an image matching model, it requires pairs of images to be matched.
|
||||||
|
The raw outputs contain the list of keypoints detected by the keypoint detector as well as the list of matches with their corresponding
|
||||||
|
matching scores.
|
||||||
|
```python
|
||||||
|
from transformers import AutoImageProcessor, AutoModel
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
|
||||||
|
url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
|
||||||
|
image1 = Image.open(requests.get(url_image1, stream=True).raw)
|
||||||
|
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
|
||||||
|
image2 = Image.open(requests.get(url_image2, stream=True).raw)
|
||||||
|
|
||||||
|
images = [image1, image2]
|
||||||
|
|
||||||
|
processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
|
||||||
|
model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")
|
||||||
|
|
||||||
|
inputs = processor(images, return_tensors="pt")
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can use the `post_process_keypoint_matching` method from the `LightGlueImageProcessor` to get the keypoints and matches in a readable format:
|
||||||
|
```python
|
||||||
|
image_sizes = [[(image.height, image.width) for image in images]]
|
||||||
|
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
|
||||||
|
for i, output in enumerate(outputs):
|
||||||
|
print("For the image pair", i)
|
||||||
|
for keypoint0, keypoint1, matching_score in zip(
|
||||||
|
output["keypoints0"], output["keypoints1"], output["matching_scores"]
|
||||||
|
):
|
||||||
|
print(
|
||||||
|
f"Keypoint at coordinate {keypoint0.numpy()} in the first image matches with keypoint at coordinate {keypoint1.numpy()} in the second image with a score of {matching_score}."
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can visualize the matches between the images by providing the original images as well as the outputs to this method:
|
||||||
|
```python
|
||||||
|
processor.plot_keypoint_matching(images, outputs)
|
||||||
|
```
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
This model was contributed by [stevenbucaille](https://huggingface.co/stevenbucaille).
|
||||||
|
The original code can be found [here](https://github.com/cvg/LightGlue).
|
||||||
|
|
||||||
|
## LightGlueConfig
|
||||||
|
|
||||||
|
[[autodoc]] LightGlueConfig
|
||||||
|
|
||||||
|
## LightGlueImageProcessor
|
||||||
|
|
||||||
|
[[autodoc]] LightGlueImageProcessor
|
||||||
|
|
||||||
|
- preprocess
|
||||||
|
- post_process_keypoint_matching
|
||||||
|
- plot_keypoint_matching
|
||||||
|
|
||||||
|
## LightGlueForKeypointMatching
|
||||||
|
|
||||||
|
[[autodoc]] LightGlueForKeypointMatching
|
||||||
|
|
||||||
|
- forward
|
||||||
@@ -231,6 +231,7 @@ _import_structure = {
|
|||||||
"is_faiss_available",
|
"is_faiss_available",
|
||||||
"is_flax_available",
|
"is_flax_available",
|
||||||
"is_keras_nlp_available",
|
"is_keras_nlp_available",
|
||||||
|
"is_matplotlib_available",
|
||||||
"is_phonemizer_available",
|
"is_phonemizer_available",
|
||||||
"is_psutil_available",
|
"is_psutil_available",
|
||||||
"is_py3nvml_available",
|
"is_py3nvml_available",
|
||||||
@@ -728,6 +729,7 @@ if TYPE_CHECKING:
|
|||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_keras_nlp_available,
|
is_keras_nlp_available,
|
||||||
|
is_matplotlib_available,
|
||||||
is_phonemizer_available,
|
is_phonemizer_available,
|
||||||
is_psutil_available,
|
is_psutil_available,
|
||||||
is_py3nvml_available,
|
is_py3nvml_available,
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ if TYPE_CHECKING:
|
|||||||
from .layoutxlm import *
|
from .layoutxlm import *
|
||||||
from .led import *
|
from .led import *
|
||||||
from .levit import *
|
from .levit import *
|
||||||
|
from .lightglue import *
|
||||||
from .lilt import *
|
from .lilt import *
|
||||||
from .llama import *
|
from .llama import *
|
||||||
from .llama4 import *
|
from .llama4 import *
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
|
|||||||
("layoutlmv3", "LayoutLMv3Config"),
|
("layoutlmv3", "LayoutLMv3Config"),
|
||||||
("led", "LEDConfig"),
|
("led", "LEDConfig"),
|
||||||
("levit", "LevitConfig"),
|
("levit", "LevitConfig"),
|
||||||
|
("lightglue", "LightGlueConfig"),
|
||||||
("lilt", "LiltConfig"),
|
("lilt", "LiltConfig"),
|
||||||
("llama", "LlamaConfig"),
|
("llama", "LlamaConfig"),
|
||||||
("llama4", "Llama4Config"),
|
("llama4", "Llama4Config"),
|
||||||
@@ -556,6 +557,7 @@ MODEL_NAMES_MAPPING = OrderedDict[str, str](
|
|||||||
("layoutxlm", "LayoutXLM"),
|
("layoutxlm", "LayoutXLM"),
|
||||||
("led", "LED"),
|
("led", "LED"),
|
||||||
("levit", "LeViT"),
|
("levit", "LeViT"),
|
||||||
|
("lightglue", "LightGlue"),
|
||||||
("lilt", "LiLT"),
|
("lilt", "LiLT"),
|
||||||
("llama", "LLaMA"),
|
("llama", "LLaMA"),
|
||||||
("llama2", "Llama2"),
|
("llama2", "Llama2"),
|
||||||
|
|||||||
@@ -106,6 +106,7 @@ else:
|
|||||||
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
|
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
|
||||||
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
|
||||||
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
|
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
|
||||||
|
("lightglue", ("LightGlueImageProcessor",)),
|
||||||
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
|
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
|
||||||
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
|
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
|
||||||
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
|
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
|
||||||
|
|||||||
@@ -175,6 +175,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
|
|||||||
("layoutlmv3", "LayoutLMv3Model"),
|
("layoutlmv3", "LayoutLMv3Model"),
|
||||||
("led", "LEDModel"),
|
("led", "LEDModel"),
|
||||||
("levit", "LevitModel"),
|
("levit", "LevitModel"),
|
||||||
|
("lightglue", "LightGlueForKeypointMatching"),
|
||||||
("lilt", "LiltModel"),
|
("lilt", "LiltModel"),
|
||||||
("llama", "LlamaModel"),
|
("llama", "LlamaModel"),
|
||||||
("llama4", "Llama4ForConditionalGeneration"),
|
("llama4", "Llama4ForConditionalGeneration"),
|
||||||
|
|||||||
28
src/transformers/models/lightglue/__init__.py
Normal file
28
src/transformers/models/lightglue/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import _LazyModule
|
||||||
|
from ...utils.import_utils import define_import_structure
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .configuration_lightglue import *
|
||||||
|
from .image_processing_lightglue import *
|
||||||
|
from .modeling_lightglue import *
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
_file = globals()["__file__"]
|
||||||
|
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
|
||||||
143
src/transformers/models/lightglue/configuration_lightglue.py
Normal file
143
src/transformers/models/lightglue/configuration_lightglue.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_lightglue.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ..auto import CONFIG_MAPPING, AutoConfig
|
||||||
|
from ..superpoint import SuperPointConfig
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to
|
||||||
|
instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
|
configuration with the defaults will yield a similar configuration to that of the LightGlue
|
||||||
|
[ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
|
||||||
|
The config object or dictionary of the keypoint detector.
|
||||||
|
descriptor_dim (`int`, *optional*, defaults to 256):
|
||||||
|
The dimension of the descriptors.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 9):
|
||||||
|
The number of self and cross attention layers.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 4):
|
||||||
|
The number of heads in the multi-head attention.
|
||||||
|
num_key_value_heads (`int`, *optional*):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
depth_confidence (`float`, *optional*, defaults to 0.95):
|
||||||
|
The confidence threshold used to perform early stopping
|
||||||
|
width_confidence (`float`, *optional*, defaults to 0.99):
|
||||||
|
The confidence threshold used to prune points
|
||||||
|
filter_threshold (`float`, *optional*, defaults to 0.1):
|
||||||
|
The confidence threshold used to filter matches
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
hidden_act (`str`, *optional*, defaults to `"gelu"`):
|
||||||
|
The activation function to be used in the hidden layers.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
attention_bias (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
>>> from transformers import LightGlueConfig, LightGlueForKeypointMatching
|
||||||
|
|
||||||
|
>>> # Initializing a LightGlue style configuration
|
||||||
|
>>> configuration = LightGlueConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the LightGlue style configuration
|
||||||
|
>>> model = LightGlueForKeypointMatching(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "lightglue"
|
||||||
|
sub_configs = {"keypoint_detector_config": AutoConfig}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
keypoint_detector_config: SuperPointConfig = None,
|
||||||
|
descriptor_dim: int = 256,
|
||||||
|
num_hidden_layers: int = 9,
|
||||||
|
num_attention_heads: int = 4,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
depth_confidence: float = 0.95,
|
||||||
|
width_confidence: float = 0.99,
|
||||||
|
filter_threshold: float = 0.1,
|
||||||
|
initializer_range: float = 0.02,
|
||||||
|
hidden_act: str = "gelu",
|
||||||
|
attention_dropout=0.0,
|
||||||
|
attention_bias=True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if descriptor_dim % num_attention_heads != 0:
|
||||||
|
raise ValueError("descriptor_dim % num_heads is different from zero")
|
||||||
|
|
||||||
|
self.descriptor_dim = descriptor_dim
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
|
||||||
|
self.depth_confidence = depth_confidence
|
||||||
|
self.width_confidence = width_confidence
|
||||||
|
self.filter_threshold = filter_threshold
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
|
||||||
|
# Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
|
||||||
|
# See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
|
||||||
|
if isinstance(keypoint_detector_config, dict):
|
||||||
|
keypoint_detector_config["model_type"] = (
|
||||||
|
keypoint_detector_config["model_type"] if "model_type" in keypoint_detector_config else "superpoint"
|
||||||
|
)
|
||||||
|
keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
|
||||||
|
**keypoint_detector_config, attn_implementation="eager"
|
||||||
|
)
|
||||||
|
if keypoint_detector_config is None:
|
||||||
|
keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
|
||||||
|
|
||||||
|
self.keypoint_detector_config = keypoint_detector_config
|
||||||
|
|
||||||
|
self.hidden_size = descriptor_dim
|
||||||
|
self.intermediate_size = descriptor_dim * 2
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["LightGlueConfig"]
|
||||||
281
src/transformers/models/lightglue/convert_lightglue_to_hf.py
Normal file
281
src/transformers/models/lightglue/convert_lightglue_to_hf.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForKeypointDetection,
|
||||||
|
LightGlueForKeypointMatching,
|
||||||
|
LightGlueImageProcessor,
|
||||||
|
)
|
||||||
|
from transformers.models.lightglue.configuration_lightglue import LightGlueConfig
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CHECKPOINT_URL = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_lightglue.pth"
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_imgs():
|
||||||
|
dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
|
||||||
|
image0 = dataset[0]["image"]
|
||||||
|
image1 = dataset[1]["image"]
|
||||||
|
image2 = dataset[2]["image"]
|
||||||
|
# [image1, image1] on purpose to test the model early stopping
|
||||||
|
return [[image2, image0], [image1, image1]]
|
||||||
|
|
||||||
|
|
||||||
|
def verify_model_outputs(model, device):
|
||||||
|
images = prepare_imgs()
|
||||||
|
preprocessor = LightGlueImageProcessor()
|
||||||
|
inputs = preprocessor(images=images, return_tensors="pt").to(device)
|
||||||
|
model.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
|
||||||
|
|
||||||
|
predicted_matches_values = outputs.matches[0, 0, 20:30]
|
||||||
|
predicted_matching_scores_values = outputs.matching_scores[0, 0, 20:30]
|
||||||
|
|
||||||
|
predicted_number_of_matches = torch.sum(outputs.matches[0][0] != -1).item()
|
||||||
|
|
||||||
|
expected_max_number_keypoints = 866
|
||||||
|
expected_matches_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
|
||||||
|
expected_matching_scores_shape = torch.Size((len(images), 2, expected_max_number_keypoints))
|
||||||
|
|
||||||
|
expected_matches_values = torch.tensor([-1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64).to(device)
|
||||||
|
expected_matching_scores_values = torch.tensor([0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583]).to(device)
|
||||||
|
|
||||||
|
expected_number_of_matches = 140
|
||||||
|
|
||||||
|
assert outputs.matches.shape == expected_matches_shape
|
||||||
|
assert outputs.matching_scores.shape == expected_matching_scores_shape
|
||||||
|
|
||||||
|
assert torch.allclose(predicted_matches_values, expected_matches_values, atol=1e-2)
|
||||||
|
assert torch.allclose(predicted_matching_scores_values, expected_matching_scores_values, atol=1e-2)
|
||||||
|
|
||||||
|
assert predicted_number_of_matches == expected_number_of_matches
|
||||||
|
|
||||||
|
|
||||||
|
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
|
||||||
|
r"posenc.Wr": r"positional_encoder.projector",
|
||||||
|
r"self_attn.(\d+).Wqkv": r"transformer_layers.\1.self_attention.Wqkv",
|
||||||
|
r"self_attn.(\d+).out_proj": r"transformer_layers.\1.self_attention.o_proj",
|
||||||
|
r"self_attn.(\d+).ffn.0": r"transformer_layers.\1.self_mlp.fc1",
|
||||||
|
r"self_attn.(\d+).ffn.1": r"transformer_layers.\1.self_mlp.layer_norm",
|
||||||
|
r"self_attn.(\d+).ffn.3": r"transformer_layers.\1.self_mlp.fc2",
|
||||||
|
r"cross_attn.(\d+).to_qk": r"transformer_layers.\1.cross_attention.to_qk",
|
||||||
|
r"cross_attn.(\d+).to_v": r"transformer_layers.\1.cross_attention.v_proj",
|
||||||
|
r"cross_attn.(\d+).to_out": r"transformer_layers.\1.cross_attention.o_proj",
|
||||||
|
r"cross_attn.(\d+).ffn.0": r"transformer_layers.\1.cross_mlp.fc1",
|
||||||
|
r"cross_attn.(\d+).ffn.1": r"transformer_layers.\1.cross_mlp.layer_norm",
|
||||||
|
r"cross_attn.(\d+).ffn.3": r"transformer_layers.\1.cross_mlp.fc2",
|
||||||
|
r"log_assignment.(\d+).matchability": r"match_assignment_layers.\1.matchability",
|
||||||
|
r"log_assignment.(\d+).final_proj": r"match_assignment_layers.\1.final_projection",
|
||||||
|
r"token_confidence.(\d+).token.0": r"token_confidence.\1.token",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def convert_old_keys_to_new_keys(state_dict_keys: List[str]):
|
||||||
|
"""
|
||||||
|
This function should be applied only once, on the concatenated keys to efficiently rename using
|
||||||
|
the key mappings.
|
||||||
|
"""
|
||||||
|
output_dict = {}
|
||||||
|
if state_dict_keys is not None:
|
||||||
|
old_text = "\n".join(state_dict_keys)
|
||||||
|
new_text = old_text
|
||||||
|
for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
|
||||||
|
if replacement is None:
|
||||||
|
new_text = re.sub(pattern, "", new_text) # an empty line
|
||||||
|
continue
|
||||||
|
new_text = re.sub(pattern, replacement, new_text)
|
||||||
|
output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
|
||||||
|
return output_dict
|
||||||
|
|
||||||
|
|
||||||
|
def add_keypoint_detector_state_dict(lightglue_state_dict):
|
||||||
|
keypoint_detector = AutoModelForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
|
||||||
|
keypoint_detector_state_dict = keypoint_detector.state_dict()
|
||||||
|
for k, v in keypoint_detector_state_dict.items():
|
||||||
|
lightglue_state_dict[f"keypoint_detector.{k}"] = v
|
||||||
|
return lightglue_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def split_weights(state_dict):
|
||||||
|
for i in range(9):
|
||||||
|
# Remove unused r values
|
||||||
|
log_assignment_r_key = f"log_assignment.{i}.r"
|
||||||
|
if state_dict.get(log_assignment_r_key, None) is not None:
|
||||||
|
state_dict.pop(log_assignment_r_key)
|
||||||
|
|
||||||
|
Wqkv_weight = state_dict.pop(f"transformer_layers.{i}.self_attention.Wqkv.weight")
|
||||||
|
Wqkv_bias = state_dict.pop(f"transformer_layers.{i}.self_attention.Wqkv.bias")
|
||||||
|
Wqkv_weight = Wqkv_weight.reshape(256, 3, 256)
|
||||||
|
Wqkv_bias = Wqkv_bias.reshape(256, 3)
|
||||||
|
query_weight, key_weight, value_weight = Wqkv_weight[:, 0], Wqkv_weight[:, 1], Wqkv_weight[:, 2]
|
||||||
|
query_bias, key_bias, value_bias = Wqkv_bias[:, 0], Wqkv_bias[:, 1], Wqkv_bias[:, 2]
|
||||||
|
state_dict[f"transformer_layers.{i}.self_attention.q_proj.weight"] = query_weight
|
||||||
|
state_dict[f"transformer_layers.{i}.self_attention.k_proj.weight"] = key_weight
|
||||||
|
state_dict[f"transformer_layers.{i}.self_attention.v_proj.weight"] = value_weight
|
||||||
|
state_dict[f"transformer_layers.{i}.self_attention.q_proj.bias"] = query_bias
|
||||||
|
state_dict[f"transformer_layers.{i}.self_attention.k_proj.bias"] = key_bias
|
||||||
|
state_dict[f"transformer_layers.{i}.self_attention.v_proj.bias"] = value_bias
|
||||||
|
|
||||||
|
to_qk_weight = state_dict.pop(f"transformer_layers.{i}.cross_attention.to_qk.weight")
|
||||||
|
to_qk_bias = state_dict.pop(f"transformer_layers.{i}.cross_attention.to_qk.bias")
|
||||||
|
state_dict[f"transformer_layers.{i}.cross_attention.q_proj.weight"] = to_qk_weight
|
||||||
|
state_dict[f"transformer_layers.{i}.cross_attention.q_proj.bias"] = to_qk_bias
|
||||||
|
state_dict[f"transformer_layers.{i}.cross_attention.k_proj.weight"] = to_qk_weight
|
||||||
|
state_dict[f"transformer_layers.{i}.cross_attention.k_proj.bias"] = to_qk_bias
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def write_model(
|
||||||
|
model_path,
|
||||||
|
checkpoint_url,
|
||||||
|
organization,
|
||||||
|
safe_serialization=True,
|
||||||
|
push_to_hub=False,
|
||||||
|
):
|
||||||
|
os.makedirs(model_path, exist_ok=True)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# LightGlue config
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
config = LightGlueConfig(
|
||||||
|
descriptor_dim=256,
|
||||||
|
num_hidden_layers=9,
|
||||||
|
num_attention_heads=4,
|
||||||
|
)
|
||||||
|
config.architectures = ["LightGlueForKeypointMatching"]
|
||||||
|
config.save_pretrained(model_path)
|
||||||
|
print("Model config saved successfully...")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Convert weights
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
|
print(f"Fetching all parameters from the checkpoint at {checkpoint_url}...")
|
||||||
|
original_state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
|
||||||
|
|
||||||
|
print("Converting model...")
|
||||||
|
all_keys = list(original_state_dict.keys())
|
||||||
|
new_keys = convert_old_keys_to_new_keys(all_keys)
|
||||||
|
|
||||||
|
state_dict = {}
|
||||||
|
for key in all_keys:
|
||||||
|
new_key = new_keys[key]
|
||||||
|
state_dict[new_key] = original_state_dict.pop(key).contiguous().clone()
|
||||||
|
|
||||||
|
del original_state_dict
|
||||||
|
gc.collect()
|
||||||
|
state_dict = split_weights(state_dict)
|
||||||
|
state_dict = add_keypoint_detector_state_dict(state_dict)
|
||||||
|
|
||||||
|
print("Loading the checkpoint in a LightGlue model...")
|
||||||
|
device = "cuda"
|
||||||
|
with torch.device(device):
|
||||||
|
model = LightGlueForKeypointMatching(config)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
print("Checkpoint loaded successfully...")
|
||||||
|
del model.config._name_or_path
|
||||||
|
|
||||||
|
print("Saving the model...")
|
||||||
|
model.save_pretrained(model_path, safe_serialization=safe_serialization)
|
||||||
|
del state_dict, model
|
||||||
|
|
||||||
|
# Safety check: reload the converted model
|
||||||
|
gc.collect()
|
||||||
|
print("Reloading the model to check if it's saved correctly.")
|
||||||
|
model = LightGlueForKeypointMatching.from_pretrained(model_path)
|
||||||
|
print("Model reloaded successfully.")
|
||||||
|
|
||||||
|
model_name = "lightglue"
|
||||||
|
if "superpoint" in checkpoint_url:
|
||||||
|
model_name += "_superpoint"
|
||||||
|
if checkpoint_url == DEFAULT_CHECKPOINT_URL:
|
||||||
|
print("Checking the model outputs...")
|
||||||
|
verify_model_outputs(model, device)
|
||||||
|
print("Model outputs verified successfully.")
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
print("Pushing model to the hub...")
|
||||||
|
model.push_to_hub(
|
||||||
|
repo_id=f"{organization}/{model_name}",
|
||||||
|
commit_message="Add model",
|
||||||
|
)
|
||||||
|
config.push_to_hub(repo_id=f"{organization}/{model_name}", commit_message="Add config")
|
||||||
|
|
||||||
|
write_image_processor(model_path, model_name, organization, push_to_hub=push_to_hub)
|
||||||
|
|
||||||
|
|
||||||
|
def write_image_processor(save_dir, model_name, organization, push_to_hub=False):
|
||||||
|
if "superpoint" in model_name:
|
||||||
|
image_processor = LightGlueImageProcessor(do_grayscale=True)
|
||||||
|
else:
|
||||||
|
image_processor = LightGlueImageProcessor()
|
||||||
|
image_processor.save_pretrained(save_dir)
|
||||||
|
|
||||||
|
if push_to_hub:
|
||||||
|
print("Pushing image processor to the hub...")
|
||||||
|
image_processor.push_to_hub(
|
||||||
|
repo_id=f"{organization}/{model_name}",
|
||||||
|
commit_message="Add image processor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# Required parameters
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint_url",
|
||||||
|
default=DEFAULT_CHECKPOINT_URL,
|
||||||
|
type=str,
|
||||||
|
help="URL of the original LightGlue checkpoint you'd like to convert.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pytorch_dump_folder_path",
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the output PyTorch model directory.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--save_model", action="store_true", help="Save model to local")
|
||||||
|
parser.add_argument(
|
||||||
|
"--push_to_hub",
|
||||||
|
action="store_true",
|
||||||
|
help="Push model and image preprocessor to the hub",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--organization",
|
||||||
|
default="ETH-CVG",
|
||||||
|
type=str,
|
||||||
|
help="Hub organization in which you want the model to be uploaded.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
write_model(
|
||||||
|
args.pytorch_dump_folder_path,
|
||||||
|
args.checkpoint_url,
|
||||||
|
args.organization,
|
||||||
|
safe_serialization=True,
|
||||||
|
push_to_hub=args.push_to_hub,
|
||||||
|
)
|
||||||
452
src/transformers/models/lightglue/image_processing_lightglue.py
Normal file
452
src/transformers/models/lightglue/image_processing_lightglue.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_lightglue.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
|
from ...image_transforms import resize, to_channel_dimension_format
|
||||||
|
from ...image_utils import (
|
||||||
|
ChannelDimension,
|
||||||
|
ImageInput,
|
||||||
|
ImageType,
|
||||||
|
PILImageResampling,
|
||||||
|
get_image_type,
|
||||||
|
infer_channel_dimension_format,
|
||||||
|
is_pil_image,
|
||||||
|
is_scaled_image,
|
||||||
|
is_valid_image,
|
||||||
|
is_vision_available,
|
||||||
|
to_numpy_array,
|
||||||
|
valid_images,
|
||||||
|
validate_preprocess_arguments,
|
||||||
|
)
|
||||||
|
from ...utils import TensorType, is_matplotlib_available, logging, requires_backends
|
||||||
|
from ...utils.import_utils import requires
|
||||||
|
from .modeling_lightglue import LightGlueKeypointMatchingOutput
|
||||||
|
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_grayscale(
|
||||||
|
image: ImageInput,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
):
|
||||||
|
if input_data_format == ChannelDimension.FIRST:
|
||||||
|
if image.shape[0] == 1:
|
||||||
|
return True
|
||||||
|
return np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])
|
||||||
|
elif input_data_format == ChannelDimension.LAST:
|
||||||
|
if image.shape[-1] == 1:
|
||||||
|
return True
|
||||||
|
return np.all(image[..., 0] == image[..., 1]) and np.all(image[..., 1] == image[..., 2])
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_grayscale(
|
||||||
|
image: ImageInput,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
) -> ImageInput:
|
||||||
|
"""
|
||||||
|
Converts an image to grayscale format using the NTSC formula. Only support numpy and PIL Image. TODO support torch
|
||||||
|
and tensorflow grayscale conversion
|
||||||
|
|
||||||
|
This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each
|
||||||
|
channel, because of an issue that is discussed in :
|
||||||
|
https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image):
|
||||||
|
The image to convert.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image.
|
||||||
|
"""
|
||||||
|
requires_backends(convert_to_grayscale, ["vision"])
|
||||||
|
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
if is_grayscale(image, input_data_format=input_data_format):
|
||||||
|
return image
|
||||||
|
if input_data_format == ChannelDimension.FIRST:
|
||||||
|
gray_image = image[0, ...] * 0.2989 + image[1, ...] * 0.5870 + image[2, ...] * 0.1140
|
||||||
|
gray_image = np.stack([gray_image] * 3, axis=0)
|
||||||
|
elif input_data_format == ChannelDimension.LAST:
|
||||||
|
gray_image = image[..., 0] * 0.2989 + image[..., 1] * 0.5870 + image[..., 2] * 0.1140
|
||||||
|
gray_image = np.stack([gray_image] * 3, axis=-1)
|
||||||
|
return gray_image
|
||||||
|
|
||||||
|
if not isinstance(image, PIL.Image.Image):
|
||||||
|
return image
|
||||||
|
|
||||||
|
image = image.convert("L")
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def validate_and_format_image_pairs(images: ImageInput):
|
||||||
|
error_message = (
|
||||||
|
"Input images must be a one of the following :",
|
||||||
|
" - A pair of PIL images.",
|
||||||
|
" - A pair of 3D arrays.",
|
||||||
|
" - A list of pairs of PIL images.",
|
||||||
|
" - A list of pairs of 3D arrays.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_valid_image(image):
|
||||||
|
"""images is a PIL Image or a 3D array."""
|
||||||
|
return is_pil_image(image) or (
|
||||||
|
is_valid_image(image) and get_image_type(image) != ImageType.PIL and len(image.shape) == 3
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(images, list):
|
||||||
|
if len(images) == 2 and all((_is_valid_image(image)) for image in images):
|
||||||
|
return images
|
||||||
|
if all(
|
||||||
|
isinstance(image_pair, list)
|
||||||
|
and len(image_pair) == 2
|
||||||
|
and all(_is_valid_image(image) for image in image_pair)
|
||||||
|
for image_pair in images
|
||||||
|
):
|
||||||
|
return [image for image_pair in images for image in image_pair]
|
||||||
|
raise ValueError(error_message)
|
||||||
|
|
||||||
|
|
||||||
|
@requires(backends=("torch",))
|
||||||
|
class LightGlueImageProcessor(BaseImageProcessor):
|
||||||
|
r"""
|
||||||
|
Constructs a LightGlue image processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
do_resize (`bool`, *optional*, defaults to `True`):
|
||||||
|
Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
|
||||||
|
by `do_resize` in the `preprocess` method.
|
||||||
|
size (`Dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`):
|
||||||
|
Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
|
||||||
|
`True`. Can be overridden by `size` in the `preprocess` method.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
|
||||||
|
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
|
||||||
|
the `preprocess` method.
|
||||||
|
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
|
||||||
|
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
|
||||||
|
method.
|
||||||
|
do_grayscale (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_input_names = ["pixel_values"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
do_resize: bool = True,
|
||||||
|
size: Optional[Dict[str, int]] = None,
|
||||||
|
resample: PILImageResampling = PILImageResampling.BILINEAR,
|
||||||
|
do_rescale: bool = True,
|
||||||
|
rescale_factor: float = 1 / 255,
|
||||||
|
do_grayscale: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
size = size if size is not None else {"height": 480, "width": 640}
|
||||||
|
size = get_size_dict(size, default_to_square=False)
|
||||||
|
|
||||||
|
self.do_resize = do_resize
|
||||||
|
self.size = size
|
||||||
|
self.resample = resample
|
||||||
|
self.do_rescale = do_rescale
|
||||||
|
self.rescale_factor = rescale_factor
|
||||||
|
self.do_grayscale = do_grayscale
|
||||||
|
|
||||||
|
def resize(
|
||||||
|
self,
|
||||||
|
image: np.ndarray,
|
||||||
|
size: Dict[str, int],
|
||||||
|
data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Resize an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (`np.ndarray`):
|
||||||
|
Image to resize.
|
||||||
|
size (`Dict[str, int]`):
|
||||||
|
Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format of the output image. If not provided, it will be inferred from the input
|
||||||
|
image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||||
|
from the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
"""
|
||||||
|
size = get_size_dict(size, default_to_square=False)
|
||||||
|
|
||||||
|
return resize(
|
||||||
|
image,
|
||||||
|
size=(size["height"], size["width"]),
|
||||||
|
data_format=data_format,
|
||||||
|
input_data_format=input_data_format,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def preprocess(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
do_resize: Optional[bool] = None,
|
||||||
|
size: Optional[Dict[str, int]] = None,
|
||||||
|
resample: PILImageResampling = None,
|
||||||
|
do_rescale: Optional[bool] = None,
|
||||||
|
rescale_factor: Optional[float] = None,
|
||||||
|
do_grayscale: Optional[bool] = None,
|
||||||
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
|
data_format: ChannelDimension = ChannelDimension.FIRST,
|
||||||
|
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> BatchFeature:
|
||||||
|
"""
|
||||||
|
Preprocess an image or batch of images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image pairs to preprocess. Expects either a list of 2 images or a list of list of 2 images list with
|
||||||
|
pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set
|
||||||
|
`do_rescale=False`.
|
||||||
|
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
|
||||||
|
Whether to resize the image.
|
||||||
|
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
|
||||||
|
Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
|
||||||
|
is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
|
||||||
|
image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
|
||||||
|
`(size["shortest_edge"], size["shortest_edge"])`. Only has an effect if `do_resize` is set to `True`.
|
||||||
|
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
|
||||||
|
Resampling filter to use if resizing the image. This can be one of `PILImageResampling`, filters. Only
|
||||||
|
has an effect if `do_resize` is set to `True`.
|
||||||
|
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
|
||||||
|
Whether to rescale the image values between [0 - 1].
|
||||||
|
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
|
||||||
|
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
|
||||||
|
do_grayscale (`bool`, *optional*, defaults to `self.do_grayscale`):
|
||||||
|
Whether to convert the image to grayscale.
|
||||||
|
return_tensors (`str` or `TensorType`, *optional*):
|
||||||
|
The type of tensors to return. Can be one of:
|
||||||
|
- Unset: Return a list of `np.ndarray`.
|
||||||
|
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
|
||||||
|
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
|
||||||
|
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
|
||||||
|
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
|
||||||
|
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
|
||||||
|
The channel dimension format for the output image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- Unset: Use the channel dimension format of the input image.
|
||||||
|
input_data_format (`ChannelDimension` or `str`, *optional*):
|
||||||
|
The channel dimension format for the input image. If unset, the channel dimension format is inferred
|
||||||
|
from the input image. Can be one of:
|
||||||
|
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
|
||||||
|
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
|
||||||
|
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
do_resize = do_resize if do_resize is not None else self.do_resize
|
||||||
|
resample = resample if resample is not None else self.resample
|
||||||
|
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
|
||||||
|
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
|
||||||
|
do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale
|
||||||
|
|
||||||
|
size = size if size is not None else self.size
|
||||||
|
size = get_size_dict(size, default_to_square=False)
|
||||||
|
|
||||||
|
# Validate and convert the input images into a flattened list of images for all subsequent processing steps.
|
||||||
|
images = validate_and_format_image_pairs(images)
|
||||||
|
|
||||||
|
if not valid_images(images):
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
|
||||||
|
"torch.Tensor, tf.Tensor or jax.ndarray."
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_preprocess_arguments(
|
||||||
|
do_resize=do_resize,
|
||||||
|
size=size,
|
||||||
|
resample=resample,
|
||||||
|
do_rescale=do_rescale,
|
||||||
|
rescale_factor=rescale_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All transformations expect numpy arrays.
|
||||||
|
images = [to_numpy_array(image) for image in images]
|
||||||
|
|
||||||
|
if is_scaled_image(images[0]) and do_rescale:
|
||||||
|
logger.warning_once(
|
||||||
|
"It looks like you are trying to rescale already rescaled images. If the input"
|
||||||
|
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
|
||||||
|
)
|
||||||
|
|
||||||
|
if input_data_format is None:
|
||||||
|
# We assume that all images have the same channel dimension format.
|
||||||
|
input_data_format = infer_channel_dimension_format(images[0])
|
||||||
|
|
||||||
|
all_images = []
|
||||||
|
for image in images:
|
||||||
|
if do_resize:
|
||||||
|
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
if do_rescale:
|
||||||
|
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
if do_grayscale:
|
||||||
|
image = convert_to_grayscale(image, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
|
||||||
|
all_images.append(image)
|
||||||
|
|
||||||
|
# Convert back the flattened list of images into a list of pairs of images.
|
||||||
|
image_pairs = [all_images[i : i + 2] for i in range(0, len(all_images), 2)]
|
||||||
|
|
||||||
|
data = {"pixel_values": image_pairs}
|
||||||
|
|
||||||
|
return BatchFeature(data=data, tensor_type=return_tensors)
|
||||||
|
|
||||||
|
def post_process_keypoint_matching(
|
||||||
|
self,
|
||||||
|
outputs: LightGlueKeypointMatchingOutput,
|
||||||
|
target_sizes: Union[TensorType, List[Tuple]],
|
||||||
|
threshold: float = 0.0,
|
||||||
|
) -> List[Dict[str, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors
|
||||||
|
with coordinates absolute to the original image sizes.
|
||||||
|
Args:
|
||||||
|
outputs ([`KeypointMatchingOutput`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
|
||||||
|
Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
|
||||||
|
target size `(height, width)` of each image in the batch. This must be the original image size (before
|
||||||
|
any processing).
|
||||||
|
threshold (`float`, *optional*, defaults to 0.0):
|
||||||
|
Threshold to filter out the matches with low scores.
|
||||||
|
Returns:
|
||||||
|
`List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
|
||||||
|
of the pair, the matching scores and the matching indices.
|
||||||
|
"""
|
||||||
|
if outputs.mask.shape[0] != len(target_sizes):
|
||||||
|
raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
|
||||||
|
if not all(len(target_size) == 2 for target_size in target_sizes):
|
||||||
|
raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
|
||||||
|
|
||||||
|
if isinstance(target_sizes, List):
|
||||||
|
image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
|
||||||
|
else:
|
||||||
|
if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Each element of target_sizes must contain the size (h, w) of each image of the batch"
|
||||||
|
)
|
||||||
|
image_pair_sizes = target_sizes
|
||||||
|
|
||||||
|
keypoints = outputs.keypoints.clone()
|
||||||
|
keypoints = keypoints * image_pair_sizes.flip(-1).reshape(-1, 2, 1, 2)
|
||||||
|
keypoints = keypoints.to(torch.int32)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for mask_pair, keypoints_pair, matches, scores in zip(
|
||||||
|
outputs.mask, keypoints, outputs.matches[:, 0], outputs.matching_scores[:, 0]
|
||||||
|
):
|
||||||
|
mask0 = mask_pair[0] > 0
|
||||||
|
mask1 = mask_pair[1] > 0
|
||||||
|
keypoints0 = keypoints_pair[0][mask0]
|
||||||
|
keypoints1 = keypoints_pair[1][mask1]
|
||||||
|
matches0 = matches[mask0]
|
||||||
|
scores0 = scores[mask0]
|
||||||
|
|
||||||
|
# Filter out matches with low scores
|
||||||
|
valid_matches = torch.logical_and(scores0 > threshold, matches0 > -1)
|
||||||
|
|
||||||
|
matched_keypoints0 = keypoints0[valid_matches]
|
||||||
|
matched_keypoints1 = keypoints1[matches0[valid_matches]]
|
||||||
|
matching_scores = scores0[valid_matches]
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"keypoints0": matched_keypoints0,
|
||||||
|
"keypoints1": matched_keypoints1,
|
||||||
|
"matching_scores": matching_scores,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
|
||||||
|
"""
|
||||||
|
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
|
||||||
|
matplotlib to be installed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images (`ImageInput`):
|
||||||
|
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
|
||||||
|
a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||||
|
outputs ([`LightGlueKeypointMatchingOutput`]):
|
||||||
|
Raw outputs of the model.
|
||||||
|
"""
|
||||||
|
if is_matplotlib_available():
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
else:
|
||||||
|
raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
|
||||||
|
|
||||||
|
images = validate_and_format_image_pairs(images)
|
||||||
|
images = [to_numpy_array(image) for image in images]
|
||||||
|
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||||
|
|
||||||
|
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||||
|
height0, width0 = image_pair[0].shape[:2]
|
||||||
|
height1, width1 = image_pair[1].shape[:2]
|
||||||
|
plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
|
||||||
|
plot_image[:height0, :width0] = image_pair[0] / 255.0
|
||||||
|
plot_image[:height1, width0:] = image_pair[1] / 255.0
|
||||||
|
plt.imshow(plot_image)
|
||||||
|
plt.axis("off")
|
||||||
|
|
||||||
|
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||||
|
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||||
|
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||||
|
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||||
|
):
|
||||||
|
plt.plot(
|
||||||
|
[keypoint0_x, keypoint1_x + width0],
|
||||||
|
[keypoint0_y, keypoint1_y],
|
||||||
|
color=plt.get_cmap("RdYlGn")(matching_score.item()),
|
||||||
|
alpha=0.9,
|
||||||
|
linewidth=0.5,
|
||||||
|
)
|
||||||
|
plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
|
||||||
|
plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["LightGlueImageProcessor"]
|
||||||
926
src/transformers/models/lightglue/modeling_lightglue.py
Normal file
926
src/transformers/models/lightglue/modeling_lightglue.py
Normal file
@@ -0,0 +1,926 @@
|
|||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
|
||||||
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||||
|
# the file from the modular. If any change should be done, please apply the change to the
|
||||||
|
# modular_lightglue.py file directly. One of our CI enforces this.
|
||||||
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
from ...activations import ACT2FN
|
||||||
|
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
|
from ...processing_utils import Unpack
|
||||||
|
from ...utils import ModelOutput, auto_docstring
|
||||||
|
from ...utils.generic import can_return_tuple
|
||||||
|
from ..auto.modeling_auto import AutoModelForKeypointDetection
|
||||||
|
from .configuration_lightglue import LightGlueConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LightGlueKeypointMatchingOutput(ModelOutput):
|
||||||
|
"""
|
||||||
|
Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
|
||||||
|
the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
|
||||||
|
batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
|
||||||
|
tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
|
||||||
|
matching information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
||||||
|
Loss computed during training.
|
||||||
|
matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
|
||||||
|
Index of keypoint matched in the other image.
|
||||||
|
matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
|
||||||
|
Scores of predicted matches.
|
||||||
|
keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
|
||||||
|
Absolute (x, y) coordinates of predicted keypoints in a given image.
|
||||||
|
prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
|
||||||
|
Pruning mask indicating which keypoints are removed and at which layer.
|
||||||
|
mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
|
||||||
|
Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
|
||||||
|
information.
|
||||||
|
hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
|
||||||
|
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
|
||||||
|
num_keypoints)` returned when `output_hidden_states=True` is passed or when
|
||||||
|
`config.output_hidden_states=True`
|
||||||
|
attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
|
||||||
|
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
|
||||||
|
num_keypoints)` returned when `output_attentions=True` is passed or when
|
||||||
|
`config.output_attentions=True`
|
||||||
|
"""
|
||||||
|
|
||||||
|
loss: Optional[torch.FloatTensor] = None
|
||||||
|
matches: Optional[torch.FloatTensor] = None
|
||||||
|
matching_scores: Optional[torch.FloatTensor] = None
|
||||||
|
keypoints: Optional[torch.FloatTensor] = None
|
||||||
|
prune: Optional[torch.IntTensor] = None
|
||||||
|
mask: Optional[torch.FloatTensor] = None
|
||||||
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LightGluePositionalEncoder(nn.Module):
|
||||||
|
def __init__(self, config: LightGlueConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
|
||||||
|
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
projected_keypoints = self.projector(keypoints)
|
||||||
|
embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
|
||||||
|
cosines = torch.cos(embeddings)
|
||||||
|
sines = torch.sin(embeddings)
|
||||||
|
embeddings = (cosines, sines)
|
||||||
|
output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
# Split and rotate. Note that this function is different from e.g. Llama.
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
|
||||||
|
return rot_x
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
|
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q (`torch.Tensor`): The query tensor.
|
||||||
|
k (`torch.Tensor`): The key tensor.
|
||||||
|
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||||
|
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||||
|
position_ids (`torch.Tensor`, *optional*):
|
||||||
|
Deprecated and unused.
|
||||||
|
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||||
|
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||||
|
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||||
|
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||||
|
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||||
|
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||||
|
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||||
|
Returns:
|
||||||
|
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||||
|
"""
|
||||||
|
dtype = q.dtype
|
||||||
|
q = q.float()
|
||||||
|
k = k.float()
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def eager_attention_forward(
|
||||||
|
module: nn.Module,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor],
|
||||||
|
scaling: float,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueAttention(nn.Module):
|
||||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||||
|
|
||||||
|
def __init__(self, config: LightGlueConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
|
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.is_causal = True
|
||||||
|
|
||||||
|
self.q_proj = nn.Linear(
|
||||||
|
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||||
|
)
|
||||||
|
self.k_proj = nn.Linear(
|
||||||
|
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||||
|
)
|
||||||
|
self.v_proj = nn.Linear(
|
||||||
|
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||||
|
)
|
||||||
|
self.o_proj = nn.Linear(
|
||||||
|
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs: Unpack[FlashAttentionKwargs],
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
input_shape = hidden_states.shape[:-1]
|
||||||
|
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||||
|
|
||||||
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
current_states = encoder_hidden_states if is_cross_attention else hidden_states
|
||||||
|
current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
|
||||||
|
|
||||||
|
key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
|
||||||
|
|
||||||
|
if position_embeddings is not None:
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||||
|
|
||||||
|
attention_interface: Callable = eager_attention_forward
|
||||||
|
if self.config._attn_implementation != "eager":
|
||||||
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||||
|
|
||||||
|
attn_output, attn_weights = attention_interface(
|
||||||
|
self,
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
current_attention_mask,
|
||||||
|
dropout=0.0 if not self.training else self.attention_dropout,
|
||||||
|
scaling=self.scaling,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueMLP(nn.Module):
|
||||||
|
def __init__(self, config: LightGlueConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.activation_fn = ACT2FN[config.hidden_act]
|
||||||
|
self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
|
||||||
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||||
|
self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
|
||||||
|
|
||||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
|
hidden_states = self.activation_fn(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueTransformerLayer(nn.Module):
|
||||||
|
def __init__(self, config: LightGlueConfig, layer_idx: int):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attention = LightGlueAttention(config, layer_idx)
|
||||||
|
self.self_mlp = LightGlueMLP(config)
|
||||||
|
self.cross_attention = LightGlueAttention(config, layer_idx)
|
||||||
|
self.cross_mlp = LightGlueMLP(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
descriptors: torch.Tensor,
|
||||||
|
keypoints: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
output_hidden_states: Optional[bool] = False,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + (descriptors,)
|
||||||
|
|
||||||
|
batch_size, num_keypoints, descriptor_dim = descriptors.shape
|
||||||
|
|
||||||
|
# Self attention block
|
||||||
|
attention_output, self_attentions = self.self_attention(
|
||||||
|
descriptors,
|
||||||
|
position_embeddings=keypoints,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
|
||||||
|
output_states = self.self_mlp(intermediate_states)
|
||||||
|
self_attention_descriptors = descriptors + output_states
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
self_attention_hidden_states = (intermediate_states, output_states)
|
||||||
|
|
||||||
|
# Reshape hidden_states to group by image_pairs :
|
||||||
|
# (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
|
||||||
|
# Flip dimension 1 to perform cross attention :
|
||||||
|
# (image0, image1) -> (image1, image0)
|
||||||
|
# Reshape back to original shape :
|
||||||
|
# (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
|
||||||
|
encoder_hidden_states = (
|
||||||
|
self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
|
||||||
|
.flip(1)
|
||||||
|
.reshape(batch_size, num_keypoints, descriptor_dim)
|
||||||
|
)
|
||||||
|
# Same for mask
|
||||||
|
encoder_attention_mask = (
|
||||||
|
attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
|
||||||
|
if attention_mask is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cross attention block
|
||||||
|
cross_attention_output, cross_attentions = self.cross_attention(
|
||||||
|
self_attention_descriptors,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
|
||||||
|
cross_output_states = self.cross_mlp(cross_intermediate_states)
|
||||||
|
descriptors = self_attention_descriptors + cross_output_states
|
||||||
|
|
||||||
|
if output_hidden_states:
|
||||||
|
cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
|
||||||
|
all_hidden_states = (
|
||||||
|
all_hidden_states
|
||||||
|
+ (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
|
||||||
|
+ self_attention_hidden_states
|
||||||
|
+ (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
|
||||||
|
+ cross_attention_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)
|
||||||
|
|
||||||
|
return descriptors, all_hidden_states, all_attentions
|
||||||
|
|
||||||
|
|
||||||
|
def sigmoid_log_double_softmax(
|
||||||
|
similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""create the log assignment matrix from logits and similarity"""
|
||||||
|
batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
|
||||||
|
certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
|
||||||
|
scores0 = nn.functional.log_softmax(similarity, 2)
|
||||||
|
scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
|
||||||
|
scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
|
||||||
|
scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
|
||||||
|
scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
|
||||||
|
scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueMatchAssignmentLayer(nn.Module):
|
||||||
|
def __init__(self, config: LightGlueConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.descriptor_dim = config.descriptor_dim
|
||||||
|
self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
|
||||||
|
self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
|
||||||
|
|
||||||
|
def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||||
|
batch_size, num_keypoints, descriptor_dim = descriptors.shape
|
||||||
|
# Final projection and similarity computation
|
||||||
|
m_descriptors = self.final_projection(descriptors)
|
||||||
|
m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
|
||||||
|
m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
|
||||||
|
m_descriptors0 = m_descriptors[:, 0]
|
||||||
|
m_descriptors1 = m_descriptors[:, 1]
|
||||||
|
similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask.reshape(batch_size // 2, 2, num_keypoints)
|
||||||
|
mask0 = mask[:, 0].unsqueeze(-1)
|
||||||
|
mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
|
||||||
|
mask = mask0 * mask1
|
||||||
|
similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
|
||||||
|
|
||||||
|
# Compute matchability of descriptors
|
||||||
|
matchability = self.matchability(descriptors)
|
||||||
|
matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
|
||||||
|
matchability_0 = matchability[:, 0]
|
||||||
|
matchability_1 = matchability[:, 1]
|
||||||
|
|
||||||
|
# Compute scores from similarity and matchability
|
||||||
|
scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Get matchability of descriptors as a probability"""
|
||||||
|
matchability = self.matchability(descriptors)
|
||||||
|
matchability = nn.functional.sigmoid(matchability).squeeze(-1)
|
||||||
|
return matchability
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueTokenConfidenceLayer(nn.Module):
|
||||||
|
def __init__(self, config: LightGlueConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.token = nn.Linear(config.descriptor_dim, 1)
|
||||||
|
|
||||||
|
def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
|
||||||
|
token = self.token(descriptors.detach())
|
||||||
|
token = nn.functional.sigmoid(token).squeeze(-1)
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring
|
||||||
|
class LightGluePreTrainedModel(PreTrainedModel):
|
||||||
|
"""
|
||||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||||
|
models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = LightGlueConfig
|
||||||
|
base_model_prefix = "lightglue"
|
||||||
|
main_input_name = "pixel_values"
|
||||||
|
supports_gradient_checkpointing = False
|
||||||
|
_supports_flash_attn_2 = True
|
||||||
|
_supports_sdpa = True
|
||||||
|
|
||||||
|
def _init_weights(self, module: nn.Module) -> None:
|
||||||
|
"""Initialize the weights"""
|
||||||
|
if isinstance(module, nn.Linear):
|
||||||
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
|
if module.bias is not None:
|
||||||
|
module.bias.data.zero_()
|
||||||
|
elif isinstance(module, nn.LayerNorm):
|
||||||
|
module.bias.data.zero_()
|
||||||
|
module.weight.data.fill_(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""obtain matches from a score matrix [Bx M+1 x N+1]"""
|
||||||
|
batch_size, _, _ = scores.shape
|
||||||
|
# For each keypoint, get the best match
|
||||||
|
max0 = scores[:, :-1, :-1].max(2)
|
||||||
|
max1 = scores[:, :-1, :-1].max(1)
|
||||||
|
matches0 = max0.indices
|
||||||
|
matches1 = max1.indices
|
||||||
|
|
||||||
|
# Mutual check for matches
|
||||||
|
indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
|
||||||
|
indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
|
||||||
|
mutual0 = indices0 == matches1.gather(1, matches0)
|
||||||
|
mutual1 = indices1 == matches0.gather(1, matches1)
|
||||||
|
|
||||||
|
# Get matching scores and filter based on mutual check and thresholding
|
||||||
|
max0 = max0.values.exp()
|
||||||
|
zero = max0.new_tensor(0)
|
||||||
|
matching_scores0 = torch.where(mutual0, max0, zero)
|
||||||
|
matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
|
||||||
|
valid0 = mutual0 & (matching_scores0 > threshold)
|
||||||
|
valid1 = mutual1 & valid0.gather(1, matches1)
|
||||||
|
|
||||||
|
# Filter matches based on mutual check and thresholding of scores
|
||||||
|
matches0 = torch.where(valid0, matches0, -1)
|
||||||
|
matches1 = torch.where(valid1, matches1, -1)
|
||||||
|
matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
|
||||||
|
matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
|
||||||
|
|
||||||
|
return matches, matching_scores
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Normalize keypoints locations based on image image_shape
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
|
||||||
|
Keypoints locations in (x, y) format.
|
||||||
|
height (`int`):
|
||||||
|
Image height.
|
||||||
|
width (`int`):
|
||||||
|
Image width.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
|
||||||
|
"""
|
||||||
|
size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
|
||||||
|
shift = size / 2
|
||||||
|
scale = size.max(-1).values / 2
|
||||||
|
keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
|
||||||
|
return keypoints
|
||||||
|
|
||||||
|
|
||||||
|
@auto_docstring(
|
||||||
|
custom_intro="""
|
||||||
|
LightGlue model taking images as inputs and outputting the matching of them.
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
class LightGlueForKeypointMatching(LightGluePreTrainedModel):
|
||||||
|
"""
|
||||||
|
LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
|
||||||
|
SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
|
||||||
|
It consists of :
|
||||||
|
1. Keypoint Encoder
|
||||||
|
2. A Graph Neural Network with self and cross attention layers
|
||||||
|
3. Matching Assignment layers
|
||||||
|
|
||||||
|
The correspondence ids use -1 to indicate non-matching points.
|
||||||
|
|
||||||
|
Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
|
||||||
|
In ICCV 2023. https://arxiv.org/pdf/2306.13643.pdf
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: LightGlueConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
|
||||||
|
|
||||||
|
self.descriptor_dim = config.descriptor_dim
|
||||||
|
self.num_layers = config.num_hidden_layers
|
||||||
|
self.filter_threshold = config.filter_threshold
|
||||||
|
self.depth_confidence = config.depth_confidence
|
||||||
|
self.width_confidence = config.width_confidence
|
||||||
|
|
||||||
|
if self.descriptor_dim != config.keypoint_detector_config.descriptor_decoder_dim:
|
||||||
|
self.input_projection = nn.Linear(
|
||||||
|
config.keypoint_detector_config.descriptor_decoder_dim, self.descriptor_dim, bias=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.input_projection = nn.Identity()
|
||||||
|
|
||||||
|
self.positional_encoder = LightGluePositionalEncoder(config)
|
||||||
|
|
||||||
|
self.transformer_layers = nn.ModuleList(
|
||||||
|
[LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.match_assignment_layers = nn.ModuleList(
|
||||||
|
[LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
|
||||||
|
)
|
||||||
|
self.token_confidence = nn.ModuleList(
|
||||||
|
[LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.post_init()
|
||||||
|
|
||||||
|
def _get_confidence_threshold(self, layer_index: int) -> float:
|
||||||
|
"""scaled confidence threshold for a given layer"""
|
||||||
|
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
|
||||||
|
return np.clip(threshold, 0, 1)
|
||||||
|
|
||||||
|
def _keypoint_processing(
|
||||||
|
self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
|
||||||
|
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
descriptors = descriptors.detach().contiguous()
|
||||||
|
projected_descriptors = self.input_projection(descriptors)
|
||||||
|
keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
|
||||||
|
return projected_descriptors, keypoint_encoding_output
|
||||||
|
|
||||||
|
def _get_early_stopped_image_pairs(
|
||||||
|
self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""evaluate whether we should stop inference based on the confidence of the keypoints"""
|
||||||
|
batch_size, _ = mask.shape
|
||||||
|
if layer_index < self.num_layers - 1:
|
||||||
|
# If the current layer is not the last layer, we compute the confidence of the keypoints and check
|
||||||
|
# if we should stop the forward pass through the transformer layers for each pair of images.
|
||||||
|
keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
|
||||||
|
keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
|
||||||
|
threshold = self._get_confidence_threshold(layer_index)
|
||||||
|
ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
|
||||||
|
early_stopped_pairs = ratio_confident > self.depth_confidence
|
||||||
|
else:
|
||||||
|
# If the current layer is the last layer, we stop the forward pass through the transformer layers for
|
||||||
|
# all pairs of images.
|
||||||
|
early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
|
||||||
|
return early_stopped_pairs
|
||||||
|
|
||||||
|
def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
|
||||||
|
if early_stops is not None:
|
||||||
|
descriptors = descriptors[early_stops]
|
||||||
|
mask = mask[early_stops]
|
||||||
|
scores = self.match_assignment_layers[layer_index](descriptors, mask)
|
||||||
|
matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
|
||||||
|
return matches, matching_scores
|
||||||
|
|
||||||
|
def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
|
||||||
|
"""mask points which should be removed"""
|
||||||
|
keep = scores > (1 - self.width_confidence)
|
||||||
|
if confidences is not None: # Low-confidence points are never pruned.
|
||||||
|
keep |= confidences <= self._get_confidence_threshold(layer_index)
|
||||||
|
return keep
|
||||||
|
|
||||||
|
def _do_layer_keypoint_pruning(
|
||||||
|
self,
|
||||||
|
descriptors: torch.Tensor,
|
||||||
|
keypoints: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
prune_output: torch.Tensor,
|
||||||
|
keypoint_confidences: torch.Tensor,
|
||||||
|
layer_index: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
|
||||||
|
descriptors.
|
||||||
|
"""
|
||||||
|
batch_size, _, _ = descriptors.shape
|
||||||
|
descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
|
||||||
|
pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
|
||||||
|
pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
|
||||||
|
|
||||||
|
# For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
|
||||||
|
pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
|
||||||
|
[t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
|
||||||
|
for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
|
||||||
|
)
|
||||||
|
for i in range(batch_size):
|
||||||
|
prune_output[i, pruned_indices[i]] += 1
|
||||||
|
|
||||||
|
# Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
|
||||||
|
pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
|
||||||
|
pad_sequence(pruned_tensor, batch_first=True)
|
||||||
|
for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
|
||||||
|
)
|
||||||
|
pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
|
||||||
|
pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
|
||||||
|
|
||||||
|
return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
|
||||||
|
|
||||||
|
def _concat_early_stopped_outputs(
|
||||||
|
self,
|
||||||
|
early_stops_indices,
|
||||||
|
final_pruned_keypoints_indices,
|
||||||
|
final_pruned_keypoints_iterations,
|
||||||
|
matches,
|
||||||
|
matching_scores,
|
||||||
|
):
|
||||||
|
early_stops_indices = torch.stack(early_stops_indices)
|
||||||
|
matches, final_pruned_keypoints_indices = (
|
||||||
|
pad_sequence(tensor, batch_first=True, padding_value=-1)
|
||||||
|
for tensor in [matches, final_pruned_keypoints_indices]
|
||||||
|
)
|
||||||
|
matching_scores, final_pruned_keypoints_iterations = (
|
||||||
|
pad_sequence(tensor, batch_first=True, padding_value=0)
|
||||||
|
for tensor in [matching_scores, final_pruned_keypoints_iterations]
|
||||||
|
)
|
||||||
|
matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
|
||||||
|
tensor[early_stops_indices]
|
||||||
|
for tensor in [
|
||||||
|
matches,
|
||||||
|
matching_scores,
|
||||||
|
final_pruned_keypoints_indices,
|
||||||
|
final_pruned_keypoints_iterations,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
|
||||||
|
|
||||||
|
def _do_final_keypoint_pruning(
|
||||||
|
self,
|
||||||
|
indices: torch.Tensor,
|
||||||
|
matches: torch.Tensor,
|
||||||
|
matching_scores: torch.Tensor,
|
||||||
|
num_keypoints: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
|
||||||
|
# have tensors from
|
||||||
|
batch_size, _ = indices.shape
|
||||||
|
indices, matches, matching_scores = (
|
||||||
|
tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
|
||||||
|
)
|
||||||
|
indices0 = indices[:, 0]
|
||||||
|
indices1 = indices[:, 1]
|
||||||
|
matches0 = matches[:, 0]
|
||||||
|
matches1 = matches[:, 1]
|
||||||
|
matching_scores0 = matching_scores[:, 0]
|
||||||
|
matching_scores1 = matching_scores[:, 1]
|
||||||
|
|
||||||
|
# Prepare final matches and matching scores
|
||||||
|
_matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
|
||||||
|
_matching_scores = torch.zeros(
|
||||||
|
(batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
|
||||||
|
)
|
||||||
|
# Fill the matches and matching scores for each image pair
|
||||||
|
for i in range(batch_size // 2):
|
||||||
|
_matches[i, 0, indices0[i]] = torch.where(
|
||||||
|
matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
|
||||||
|
)
|
||||||
|
_matches[i, 1, indices1[i]] = torch.where(
|
||||||
|
matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
|
||||||
|
)
|
||||||
|
_matching_scores[i, 0, indices0[i]] = matching_scores0[i]
|
||||||
|
_matching_scores[i, 1, indices1[i]] = matching_scores1[i]
|
||||||
|
return _matches, _matching_scores
|
||||||
|
|
||||||
|
def _match_image_pair(
|
||||||
|
self,
|
||||||
|
keypoints: torch.Tensor,
|
||||||
|
descriptors: torch.Tensor,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
mask: torch.Tensor = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]:
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
if keypoints.shape[2] == 0: # no keypoints
|
||||||
|
shape = keypoints.shape[:-1]
|
||||||
|
return (
|
||||||
|
keypoints.new_full(shape, -1, dtype=torch.int),
|
||||||
|
keypoints.new_zeros(shape),
|
||||||
|
keypoints.new_zeros(shape),
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
device = keypoints.device
|
||||||
|
batch_size, _, initial_num_keypoints, _ = keypoints.shape
|
||||||
|
num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
|
||||||
|
# (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
|
||||||
|
keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
|
||||||
|
mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
|
||||||
|
descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.descriptor_dim)
|
||||||
|
image_indices = torch.arange(batch_size * 2, device=device)
|
||||||
|
# Keypoint normalization
|
||||||
|
keypoints = normalize_keypoints(keypoints, height, width)
|
||||||
|
|
||||||
|
descriptors, keypoint_encoding_output = self._keypoint_processing(
|
||||||
|
descriptors, keypoints, output_hidden_states=output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
keypoints = keypoint_encoding_output[0]
|
||||||
|
|
||||||
|
# Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
|
||||||
|
# keypoints is above a certain threshold.
|
||||||
|
do_early_stop = self.depth_confidence > 0
|
||||||
|
# Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
|
||||||
|
# the keypoints is below a certain threshold.
|
||||||
|
do_keypoint_pruning = self.width_confidence > 0
|
||||||
|
|
||||||
|
early_stops_indices = []
|
||||||
|
matches = []
|
||||||
|
matching_scores = []
|
||||||
|
final_pruned_keypoints_indices = []
|
||||||
|
final_pruned_keypoints_iterations = []
|
||||||
|
|
||||||
|
pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
|
||||||
|
pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
|
||||||
|
|
||||||
|
for layer_index in range(self.num_layers):
|
||||||
|
input_shape = descriptors.size()
|
||||||
|
if mask is not None:
|
||||||
|
extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
|
||||||
|
else:
|
||||||
|
extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
|
||||||
|
layer_output = self.transformer_layers[layer_index](
|
||||||
|
descriptors,
|
||||||
|
keypoints,
|
||||||
|
attention_mask=extended_attention_mask,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
)
|
||||||
|
descriptors, hidden_states, attention = layer_output
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states = all_hidden_states + hidden_states
|
||||||
|
if output_attentions:
|
||||||
|
all_attentions = all_attentions + attention
|
||||||
|
|
||||||
|
if do_early_stop:
|
||||||
|
if layer_index < self.num_layers - 1:
|
||||||
|
# Get the confidence of the keypoints for the current layer
|
||||||
|
keypoint_confidences = self.token_confidence[layer_index](descriptors)
|
||||||
|
|
||||||
|
# Determine which pairs of images should be early stopped based on the confidence of the keypoints for
|
||||||
|
# the current layer.
|
||||||
|
early_stopped_pairs = self._get_early_stopped_image_pairs(
|
||||||
|
keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Early stopping always occurs at the last layer
|
||||||
|
early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
|
||||||
|
|
||||||
|
if torch.any(early_stopped_pairs):
|
||||||
|
# If a pair of images is considered early stopped, we compute the matches for the remaining
|
||||||
|
# keypoints and stop the forward pass through the transformer layers for this pair of images.
|
||||||
|
early_stops = early_stopped_pairs.repeat_interleave(2)
|
||||||
|
early_stopped_image_indices = image_indices[early_stops]
|
||||||
|
early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
|
||||||
|
descriptors, mask, layer_index, early_stops=early_stops
|
||||||
|
)
|
||||||
|
early_stops_indices.extend(list(early_stopped_image_indices))
|
||||||
|
matches.extend(list(early_stopped_matches))
|
||||||
|
matching_scores.extend(list(early_stopped_matching_scores))
|
||||||
|
if do_keypoint_pruning:
|
||||||
|
final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
|
||||||
|
final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
|
||||||
|
|
||||||
|
# Remove image pairs that have been early stopped from the forward pass
|
||||||
|
num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
|
||||||
|
descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple(
|
||||||
|
(
|
||||||
|
tensor[~early_stops]
|
||||||
|
for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
keypoints = (keypoints_0, keypoint_1)
|
||||||
|
if do_keypoint_pruning:
|
||||||
|
pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
|
||||||
|
(
|
||||||
|
tensor[~early_stops]
|
||||||
|
for tensor in [
|
||||||
|
pruned_keypoints_indices,
|
||||||
|
pruned_keypoints_iterations,
|
||||||
|
keypoint_confidences,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# If all pairs of images are early stopped, we stop the forward pass through the transformer
|
||||||
|
# layers for all pairs of images.
|
||||||
|
if torch.all(early_stopped_pairs):
|
||||||
|
break
|
||||||
|
|
||||||
|
if do_keypoint_pruning:
|
||||||
|
# Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
|
||||||
|
# the keypoints is below a certain threshold.
|
||||||
|
descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
|
||||||
|
self._do_layer_keypoint_pruning(
|
||||||
|
descriptors,
|
||||||
|
keypoints,
|
||||||
|
mask,
|
||||||
|
pruned_keypoints_indices,
|
||||||
|
pruned_keypoints_iterations,
|
||||||
|
keypoint_confidences,
|
||||||
|
layer_index,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if do_early_stop and do_keypoint_pruning:
|
||||||
|
# Concatenate early stopped outputs together and perform final keypoint pruning
|
||||||
|
final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
|
||||||
|
self._concat_early_stopped_outputs(
|
||||||
|
early_stops_indices,
|
||||||
|
final_pruned_keypoints_indices,
|
||||||
|
final_pruned_keypoints_iterations,
|
||||||
|
matches,
|
||||||
|
matching_scores,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
matches, matching_scores = self._do_final_keypoint_pruning(
|
||||||
|
final_pruned_keypoints_indices,
|
||||||
|
matches,
|
||||||
|
matching_scores,
|
||||||
|
initial_num_keypoints,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
|
||||||
|
final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
|
||||||
|
|
||||||
|
final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
|
||||||
|
batch_size, 2, initial_num_keypoints
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
matches,
|
||||||
|
matching_scores,
|
||||||
|
final_pruned_keypoints_iterations,
|
||||||
|
all_hidden_states,
|
||||||
|
all_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@auto_docstring
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
pixel_values: torch.FloatTensor,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, LightGlueKeypointMatchingOutput]:
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
raise ValueError("LightGlue is not trainable, no labels should be provided.")
|
||||||
|
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
|
||||||
|
raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
|
||||||
|
|
||||||
|
batch_size, _, channels, height, width = pixel_values.shape
|
||||||
|
pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
|
||||||
|
keypoint_detections = self.keypoint_detector(pixel_values)
|
||||||
|
|
||||||
|
keypoints, _, descriptors, mask = keypoint_detections[:4]
|
||||||
|
keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
|
||||||
|
descriptors = descriptors.reshape(batch_size, 2, -1, self.descriptor_dim).to(pixel_values)
|
||||||
|
mask = mask.reshape(batch_size, 2, -1)
|
||||||
|
|
||||||
|
absolute_keypoints = keypoints.clone()
|
||||||
|
absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
|
||||||
|
absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
|
||||||
|
|
||||||
|
matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
|
||||||
|
absolute_keypoints,
|
||||||
|
descriptors,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
mask=mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LightGlueKeypointMatchingOutput(
|
||||||
|
loss=loss,
|
||||||
|
matches=matches,
|
||||||
|
matching_scores=matching_scores,
|
||||||
|
keypoints=keypoints,
|
||||||
|
prune=prune,
|
||||||
|
mask=mask,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attentions=attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching"]
|
||||||
1000
src/transformers/models/lightglue/modular_lightglue.py
Normal file
1000
src/transformers/models/lightglue/modular_lightglue.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ... import is_torch_available, is_vision_available
|
|
||||||
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
|
||||||
from ...image_transforms import resize, to_channel_dimension_format
|
from ...image_transforms import resize, to_channel_dimension_format
|
||||||
from ...image_utils import (
|
from ...image_utils import (
|
||||||
@@ -29,7 +28,9 @@ from ...image_utils import (
|
|||||||
infer_channel_dimension_format,
|
infer_channel_dimension_format,
|
||||||
is_pil_image,
|
is_pil_image,
|
||||||
is_scaled_image,
|
is_scaled_image,
|
||||||
|
is_torch_available,
|
||||||
is_valid_image,
|
is_valid_image,
|
||||||
|
is_vision_available,
|
||||||
to_numpy_array,
|
to_numpy_array,
|
||||||
valid_images,
|
valid_images,
|
||||||
validate_preprocess_arguments,
|
validate_preprocess_arguments,
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class SuperPointInterestPointDecoder(nn.Module):
|
|||||||
keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints)
|
keypoints, scores = top_k_keypoints(keypoints, scores, self.max_keypoints)
|
||||||
|
|
||||||
# Convert (y, x) to (x, y)
|
# Convert (y, x) to (x, y)
|
||||||
keypoints = torch.flip(keypoints, [1]).float()
|
keypoints = torch.flip(keypoints, [1]).to(scores.dtype)
|
||||||
|
|
||||||
return keypoints, scores
|
return keypoints, scores
|
||||||
|
|
||||||
|
|||||||
@@ -179,6 +179,7 @@ from .import_utils import (
|
|||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_liger_kernel_available,
|
is_liger_kernel_available,
|
||||||
is_lomo_available,
|
is_lomo_available,
|
||||||
|
is_matplotlib_available,
|
||||||
is_mlx_available,
|
is_mlx_available,
|
||||||
is_natten_available,
|
is_natten_available,
|
||||||
is_ninja_available,
|
is_ninja_available,
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ _triton_available = _is_package_available("triton")
|
|||||||
_spqr_available = _is_package_available("spqr_quant")
|
_spqr_available = _is_package_available("spqr_quant")
|
||||||
_rich_available = _is_package_available("rich")
|
_rich_available = _is_package_available("rich")
|
||||||
_kernels_available = _is_package_available("kernels")
|
_kernels_available = _is_package_available("kernels")
|
||||||
|
_matplotlib_available = _is_package_available("matplotlib")
|
||||||
|
|
||||||
_torch_version = "N/A"
|
_torch_version = "N/A"
|
||||||
_torch_available = False
|
_torch_available = False
|
||||||
@@ -1443,6 +1444,10 @@ def is_rich_available():
|
|||||||
return _rich_available
|
return _rich_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_matplotlib_available():
|
||||||
|
return _matplotlib_available
|
||||||
|
|
||||||
|
|
||||||
def check_torch_load_is_safe():
|
def check_torch_load_is_safe():
|
||||||
if not is_torch_greater_or_equal("2.6"):
|
if not is_torch_greater_or_equal("2.6"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
0
tests/models/lightglue/__init__.py
Normal file
0
tests/models/lightglue/__init__.py
Normal file
96
tests/models/lightglue/test_image_processing_lightglue.py
Normal file
96
tests/models/lightglue/test_image_processing_lightglue.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from tests.models.superglue.test_image_processing_superglue import (
|
||||||
|
SuperGlueImageProcessingTest,
|
||||||
|
SuperGlueImageProcessingTester,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import require_torch, require_vision
|
||||||
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.models.lightglue.modeling_lightglue import LightGlueKeypointMatchingOutput
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from transformers import LightGlueImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def random_array(size):
|
||||||
|
return np.random.randint(255, size=size)
|
||||||
|
|
||||||
|
|
||||||
|
def random_tensor(size):
|
||||||
|
return torch.rand(size)
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueImageProcessingTester(SuperGlueImageProcessingTester):
|
||||||
|
"""Tester for LightGlueImageProcessor"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=6,
|
||||||
|
num_channels=3,
|
||||||
|
image_size=18,
|
||||||
|
min_resolution=30,
|
||||||
|
max_resolution=400,
|
||||||
|
do_resize=True,
|
||||||
|
size=None,
|
||||||
|
do_grayscale=True,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
parent, batch_size, num_channels, image_size, min_resolution, max_resolution, do_resize, size, do_grayscale
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_keypoint_matching_output(self, pixel_values):
|
||||||
|
"""Prepare a fake output for the keypoint matching model with random matches between 50 keypoints per image."""
|
||||||
|
max_number_keypoints = 50
|
||||||
|
batch_size = len(pixel_values)
|
||||||
|
mask = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int)
|
||||||
|
keypoints = torch.zeros((batch_size, 2, max_number_keypoints, 2))
|
||||||
|
matches = torch.full((batch_size, 2, max_number_keypoints), -1, dtype=torch.int)
|
||||||
|
scores = torch.zeros((batch_size, 2, max_number_keypoints))
|
||||||
|
prune = torch.zeros((batch_size, 2, max_number_keypoints), dtype=torch.int)
|
||||||
|
for i in range(batch_size):
|
||||||
|
random_number_keypoints0 = np.random.randint(10, max_number_keypoints)
|
||||||
|
random_number_keypoints1 = np.random.randint(10, max_number_keypoints)
|
||||||
|
random_number_matches = np.random.randint(5, min(random_number_keypoints0, random_number_keypoints1))
|
||||||
|
mask[i, 0, :random_number_keypoints0] = 1
|
||||||
|
mask[i, 1, :random_number_keypoints1] = 1
|
||||||
|
keypoints[i, 0, :random_number_keypoints0] = torch.rand((random_number_keypoints0, 2))
|
||||||
|
keypoints[i, 1, :random_number_keypoints1] = torch.rand((random_number_keypoints1, 2))
|
||||||
|
random_matches_indices0 = torch.randperm(random_number_keypoints1, dtype=torch.int)[:random_number_matches]
|
||||||
|
random_matches_indices1 = torch.randperm(random_number_keypoints0, dtype=torch.int)[:random_number_matches]
|
||||||
|
matches[i, 0, random_matches_indices1] = random_matches_indices0
|
||||||
|
matches[i, 1, random_matches_indices0] = random_matches_indices1
|
||||||
|
scores[i, 0, random_matches_indices1] = torch.rand((random_number_matches,))
|
||||||
|
scores[i, 1, random_matches_indices0] = torch.rand((random_number_matches,))
|
||||||
|
return LightGlueKeypointMatchingOutput(
|
||||||
|
mask=mask, keypoints=keypoints, matches=matches, matching_scores=scores, prune=prune
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class LightGlueImageProcessingTest(SuperGlueImageProcessingTest, unittest.TestCase):
|
||||||
|
image_processing_class = LightGlueImageProcessor if is_vision_available() else None
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
self.image_processor_tester = LightGlueImageProcessingTester(self)
|
||||||
584
tests/models/lightglue/test_modeling_lightglue.py
Normal file
584
tests/models/lightglue/test_modeling_lightglue.py
Normal file
@@ -0,0 +1,584 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import inspect
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from transformers.models.lightglue.configuration_lightglue import LightGlueConfig
|
||||||
|
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||||
|
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
from ...test_configuration_common import ConfigTester
|
||||||
|
from ...test_modeling_common import ModelTesterMixin, floats_tensor
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import LightGlueForKeypointMatching
|
||||||
|
|
||||||
|
if is_vision_available():
|
||||||
|
from transformers import AutoImageProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class LightGlueModelTester:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
parent,
|
||||||
|
batch_size=2,
|
||||||
|
image_width=80,
|
||||||
|
image_height=60,
|
||||||
|
keypoint_detector_config={
|
||||||
|
"encoder_hidden_sizes": [32, 32, 64],
|
||||||
|
"decoder_hidden_size": 64,
|
||||||
|
"keypoint_decoder_dim": 65,
|
||||||
|
"descriptor_decoder_dim": 64,
|
||||||
|
"keypoint_threshold": 0.005,
|
||||||
|
"max_keypoints": 256,
|
||||||
|
"nms_radius": 4,
|
||||||
|
"border_removal_distance": 4,
|
||||||
|
},
|
||||||
|
descriptor_dim: int = 64,
|
||||||
|
num_layers: int = 2,
|
||||||
|
num_heads: int = 4,
|
||||||
|
depth_confidence: float = 1.0,
|
||||||
|
width_confidence: float = 1.0,
|
||||||
|
filter_threshold: float = 0.1,
|
||||||
|
matching_threshold: float = 0.0,
|
||||||
|
):
|
||||||
|
self.parent = parent
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.image_width = image_width
|
||||||
|
self.image_height = image_height
|
||||||
|
|
||||||
|
self.keypoint_detector_config = keypoint_detector_config
|
||||||
|
self.descriptor_dim = descriptor_dim
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.depth_confidence = depth_confidence
|
||||||
|
self.width_confidence = width_confidence
|
||||||
|
self.filter_threshold = filter_threshold
|
||||||
|
self.matching_threshold = matching_threshold
|
||||||
|
|
||||||
|
def prepare_config_and_inputs(self):
|
||||||
|
# LightGlue expects a grayscale image as input
|
||||||
|
pixel_values = floats_tensor([self.batch_size, 2, 3, self.image_height, self.image_width])
|
||||||
|
config = self.get_config()
|
||||||
|
return config, pixel_values
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return LightGlueConfig(
|
||||||
|
keypoint_detector_config=self.keypoint_detector_config,
|
||||||
|
descriptor_dim=self.descriptor_dim,
|
||||||
|
num_hidden_layers=self.num_layers,
|
||||||
|
num_attention_heads=self.num_heads,
|
||||||
|
depth_confidence=self.depth_confidence,
|
||||||
|
width_confidence=self.width_confidence,
|
||||||
|
filter_threshold=self.filter_threshold,
|
||||||
|
matching_threshold=self.matching_threshold,
|
||||||
|
attn_implementation="eager",
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_and_check_model(self, config, pixel_values):
|
||||||
|
model = LightGlueForKeypointMatching(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
result = model(pixel_values)
|
||||||
|
maximum_num_matches = result.mask.shape[-1]
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.keypoints.shape,
|
||||||
|
(self.batch_size, 2, maximum_num_matches, 2),
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.matches.shape,
|
||||||
|
(self.batch_size, 2, maximum_num_matches),
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.matching_scores.shape,
|
||||||
|
(self.batch_size, 2, maximum_num_matches),
|
||||||
|
)
|
||||||
|
self.parent.assertEqual(
|
||||||
|
result.prune.shape,
|
||||||
|
(self.batch_size, 2, maximum_num_matches),
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_config_and_inputs_for_common(self):
|
||||||
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
config, pixel_values = config_and_inputs
|
||||||
|
inputs_dict = {"pixel_values": pixel_values}
|
||||||
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class LightGlueModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
|
all_model_classes = (LightGlueForKeypointMatching,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = () if is_torch_available() else ()
|
||||||
|
|
||||||
|
test_pruning = False
|
||||||
|
test_resize_embeddings = False
|
||||||
|
test_head_masking = False
|
||||||
|
has_attentions = True
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.model_tester = LightGlueModelTester(self)
|
||||||
|
self.config_tester = ConfigTester(self, config_class=LightGlueConfig, has_text_modality=False, hidden_size=37)
|
||||||
|
|
||||||
|
def test_config(self):
|
||||||
|
self.config_tester.create_and_test_config_to_json_string()
|
||||||
|
self.config_tester.create_and_test_config_to_json_file()
|
||||||
|
self.config_tester.create_and_test_config_from_and_save_pretrained()
|
||||||
|
self.config_tester.create_and_test_config_with_num_labels()
|
||||||
|
self.config_tester.check_config_can_be_init_without_params()
|
||||||
|
self.config_tester.check_config_arguments_init()
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching does not use inputs_embeds")
|
||||||
|
def test_inputs_embeds(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching does not support input and output embeddings")
|
||||||
|
def test_model_get_set_embeddings(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching does not use feedforward chunking")
|
||||||
|
def test_feed_forward_chunking(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
|
||||||
|
def test_training(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
|
||||||
|
def test_training_gradient_checkpointing(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlueForKeypointMatching is not trainable")
|
||||||
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.skip(reason="LightGlue does not output any loss term in the forward pass")
|
||||||
|
def test_retain_grad_hidden_states_attentions(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_model(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_forward_signature(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
signature = inspect.signature(model.forward)
|
||||||
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
|
expected_arg_names = ["pixel_values"]
|
||||||
|
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||||
|
|
||||||
|
def test_hidden_states_output(self):
|
||||||
|
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
hidden_states = outputs.hidden_states
|
||||||
|
maximum_num_matches = outputs.mask.shape[-1]
|
||||||
|
|
||||||
|
hidden_states_sizes = [
|
||||||
|
self.model_tester.descriptor_dim,
|
||||||
|
self.model_tester.descriptor_dim,
|
||||||
|
self.model_tester.descriptor_dim * 2,
|
||||||
|
self.model_tester.descriptor_dim,
|
||||||
|
self.model_tester.descriptor_dim,
|
||||||
|
self.model_tester.descriptor_dim * 2,
|
||||||
|
self.model_tester.descriptor_dim,
|
||||||
|
] * self.model_tester.num_layers
|
||||||
|
|
||||||
|
for i, hidden_states_size in enumerate(hidden_states_sizes):
|
||||||
|
self.assertListEqual(
|
||||||
|
list(hidden_states[i].shape[-2:]),
|
||||||
|
[maximum_num_matches, hidden_states_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_hidden_states"] = True
|
||||||
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
# check that output_hidden_states also work using config
|
||||||
|
del inputs_dict["output_hidden_states"]
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
check_hidden_states_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
def test_attention_outputs(self):
|
||||||
|
def check_attention_output(inputs_dict, config, model_class):
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
attentions = outputs.attentions
|
||||||
|
maximum_num_matches = outputs.mask.shape[-1]
|
||||||
|
|
||||||
|
expected_attention_shape = [self.model_tester.num_heads, maximum_num_matches, maximum_num_matches]
|
||||||
|
|
||||||
|
for i, attention in enumerate(attentions):
|
||||||
|
self.assertListEqual(
|
||||||
|
list(attention.shape[-3:]),
|
||||||
|
expected_attention_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict["output_attentions"] = True
|
||||||
|
check_attention_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
# check that output_hidden_states also work using config
|
||||||
|
del inputs_dict["output_attentions"]
|
||||||
|
config.output_attentions = True
|
||||||
|
|
||||||
|
check_attention_output(inputs_dict, config, model_class)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
from_pretrained_ids = ["ETH-CVG/lightglue_superpoint"]
|
||||||
|
for model_name in from_pretrained_ids:
|
||||||
|
model = LightGlueForKeypointMatching.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
# Copied from tests.models.superglue.test_modeling_superglue.SuperGlueModelTest.test_forward_labels_should_be_none
|
||||||
|
def test_forward_labels_should_be_none(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
# Provide an arbitrary sized Tensor as labels to model inputs
|
||||||
|
model_inputs["labels"] = torch.rand((128, 128))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
model(**model_inputs)
|
||||||
|
self.assertEqual(ValueError, cm.exception.__class__)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_imgs():
|
||||||
|
dataset = load_dataset("hf-internal-testing/image-matching-test-dataset", split="train")
|
||||||
|
image0 = dataset[0]["image"]
|
||||||
|
image1 = dataset[1]["image"]
|
||||||
|
image2 = dataset[2]["image"]
|
||||||
|
# [image1, image1] on purpose to test the model early stopping
|
||||||
|
return [[image2, image0], [image1, image1]]
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
class LightGlueModelIntegrationTest(unittest.TestCase):
|
||||||
|
@cached_property
|
||||||
|
def default_image_processor(self):
|
||||||
|
return AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint") if is_vision_available() else None
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference(self):
|
||||||
|
model = LightGlueForKeypointMatching.from_pretrained(
|
||||||
|
"ETH-CVG/lightglue_superpoint", attn_implementation="eager"
|
||||||
|
).to(torch_device)
|
||||||
|
preprocessor = self.default_image_processor
|
||||||
|
images = prepare_imgs()
|
||||||
|
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
|
||||||
|
|
||||||
|
predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item()
|
||||||
|
predicted_matches_values0 = outputs.matches[0, 0, 10:30]
|
||||||
|
predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30]
|
||||||
|
|
||||||
|
predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item()
|
||||||
|
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
|
||||||
|
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
|
||||||
|
|
||||||
|
expected_number_of_matches0 = 140
|
||||||
|
expected_matches_values0 = torch.tensor(
|
||||||
|
[14, -1, -1, 15, 17, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
expected_matching_scores_values0 = torch.tensor(
|
||||||
|
[0.3796, 0, 0, 0.3772, 0.4439, 0.2411, 0, 0, 0.0032, 0, 0, 0, 0.2997, 0, 0, 0.6762, 0, 0.8826, 0, 0.5583],
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_number_of_matches1 = 866
|
||||||
|
expected_matches_values1 = torch.tensor(
|
||||||
|
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=torch_device,
|
||||||
|
)
|
||||||
|
expected_matching_scores_values1 = torch.tensor(
|
||||||
|
[
|
||||||
|
0.6188,0.7817,0.5686,0.9353,0.9801,0.9193,0.8632,0.9111,0.9821,0.5496,
|
||||||
|
0.9906,0.8682,0.9679,0.9914,0.9318,0.1910,0.9669,0.3240,0.9971,0.9923,
|
||||||
|
],
|
||||||
|
device=torch_device
|
||||||
|
) # fmt:skip
|
||||||
|
|
||||||
|
# expected_early_stopping_layer = 2
|
||||||
|
# predicted_early_stopping_layer = torch.max(outputs.prune[1]).item()
|
||||||
|
# self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer)
|
||||||
|
# self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
|
||||||
|
on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this
|
||||||
|
specific test example). The consequence of having different number of keypoints is that the number of matches
|
||||||
|
will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less
|
||||||
|
match. The matching scores will also be different, as the keypoints are different. The checks here are less
|
||||||
|
strict to account for these inconsistencies.
|
||||||
|
Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
|
||||||
|
expected values, individually. Here, the tolerance of the number of values changing is set to 2.
|
||||||
|
|
||||||
|
This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
|
||||||
|
Such CUDA inconsistencies can be found
|
||||||
|
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
|
||||||
|
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2))
|
||||||
|
< 4
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2))
|
||||||
|
< 4
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
|
||||||
|
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_without_early_stop(self):
|
||||||
|
model = LightGlueForKeypointMatching.from_pretrained(
|
||||||
|
"ETH-CVG/lightglue_superpoint", attn_implementation="eager", depth_confidence=1.0
|
||||||
|
).to(torch_device)
|
||||||
|
preprocessor = self.default_image_processor
|
||||||
|
images = prepare_imgs()
|
||||||
|
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
|
||||||
|
|
||||||
|
predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item()
|
||||||
|
predicted_matches_values0 = outputs.matches[0, 0, 10:30]
|
||||||
|
predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30]
|
||||||
|
|
||||||
|
predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item()
|
||||||
|
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
|
||||||
|
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
|
||||||
|
|
||||||
|
expected_number_of_matches0 = 134
|
||||||
|
expected_matches_values0 = torch.tensor(
|
||||||
|
[-1, -1, 17, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64
|
||||||
|
).to(torch_device)
|
||||||
|
expected_matching_scores_values0 = torch.tensor(
|
||||||
|
[0.0083, 0, 0.2022, 0.0621, 0, 0.0828, 0, 0, 0.0003, 0, 0, 0, 0.0960, 0, 0, 0.6940, 0, 0.7167, 0, 0.1512]
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
expected_number_of_matches1 = 862
|
||||||
|
expected_matches_values1 = torch.tensor(
|
||||||
|
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=torch.int64
|
||||||
|
).to(torch_device)
|
||||||
|
expected_matching_scores_values1 = torch.tensor(
|
||||||
|
[
|
||||||
|
0.4772,
|
||||||
|
0.3781,
|
||||||
|
0.0631,
|
||||||
|
0.9559,
|
||||||
|
0.8746,
|
||||||
|
0.9271,
|
||||||
|
0.4882,
|
||||||
|
0.5406,
|
||||||
|
0.9439,
|
||||||
|
0.1526,
|
||||||
|
0.5028,
|
||||||
|
0.4107,
|
||||||
|
0.5591,
|
||||||
|
0.9130,
|
||||||
|
0.7572,
|
||||||
|
0.0302,
|
||||||
|
0.4532,
|
||||||
|
0.0893,
|
||||||
|
0.9490,
|
||||||
|
0.4880,
|
||||||
|
]
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# expected_early_stopping_layer = 2
|
||||||
|
# predicted_early_stopping_layer = torch.max(outputs.prune[1]).item()
|
||||||
|
# self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer)
|
||||||
|
# self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
|
||||||
|
on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this
|
||||||
|
specific test example). The consequence of having different number of keypoints is that the number of matches
|
||||||
|
will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less
|
||||||
|
match. The matching scores will also be different, as the keypoints are different. The checks here are less
|
||||||
|
strict to account for these inconsistencies.
|
||||||
|
Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
|
||||||
|
expected values, individually. Here, the tolerance of the number of values changing is set to 2.
|
||||||
|
|
||||||
|
This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
|
||||||
|
Such CUDA inconsistencies can be found
|
||||||
|
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
|
||||||
|
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2))
|
||||||
|
< 4
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2))
|
||||||
|
< 4
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
|
||||||
|
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_inference_without_early_stop_and_keypoint_pruning(self):
|
||||||
|
model = LightGlueForKeypointMatching.from_pretrained(
|
||||||
|
"ETH-CVG/lightglue_superpoint",
|
||||||
|
attn_implementation="eager",
|
||||||
|
depth_confidence=1.0,
|
||||||
|
width_confidence=1.0,
|
||||||
|
).to(torch_device)
|
||||||
|
preprocessor = self.default_image_processor
|
||||||
|
images = prepare_imgs()
|
||||||
|
inputs = preprocessor(images=images, return_tensors="pt").to(torch_device)
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)
|
||||||
|
|
||||||
|
predicted_number_of_matches0 = torch.sum(outputs.matches[0][0] != -1).item()
|
||||||
|
predicted_matches_values0 = outputs.matches[0, 0, 10:30]
|
||||||
|
predicted_matching_scores_values0 = outputs.matching_scores[0, 0, 10:30]
|
||||||
|
|
||||||
|
predicted_number_of_matches1 = torch.sum(outputs.matches[1][0] != -1).item()
|
||||||
|
predicted_matches_values1 = outputs.matches[1, 0, 10:30]
|
||||||
|
predicted_matching_scores_values1 = outputs.matching_scores[1, 0, 10:30]
|
||||||
|
|
||||||
|
expected_number_of_matches0 = 144
|
||||||
|
expected_matches_values0 = torch.tensor(
|
||||||
|
[-1, -1, 17, -1, -1, 13, -1, -1, -1, -1, -1, -1, 5, -1, -1, 19, -1, 10, -1, 11], dtype=torch.int64
|
||||||
|
).to(torch_device)
|
||||||
|
expected_matching_scores_values0 = torch.tensor(
|
||||||
|
[
|
||||||
|
0.0699,
|
||||||
|
0.0302,
|
||||||
|
0.3356,
|
||||||
|
0.0820,
|
||||||
|
0,
|
||||||
|
0.2266,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0.0241,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0.1674,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0.8114,
|
||||||
|
0,
|
||||||
|
0.8120,
|
||||||
|
0,
|
||||||
|
0.2936,
|
||||||
|
]
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
expected_number_of_matches1 = 862
|
||||||
|
expected_matches_values1 = torch.tensor(
|
||||||
|
[10, 11, -1, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, -1, 26, -1, 28, 29], dtype=torch.int64
|
||||||
|
).to(torch_device)
|
||||||
|
expected_matching_scores_values1 = torch.tensor(
|
||||||
|
[
|
||||||
|
0.4772,
|
||||||
|
0.3781,
|
||||||
|
0.0631,
|
||||||
|
0.9559,
|
||||||
|
0.8746,
|
||||||
|
0.9271,
|
||||||
|
0.4882,
|
||||||
|
0.5406,
|
||||||
|
0.9439,
|
||||||
|
0.1526,
|
||||||
|
0.5028,
|
||||||
|
0.4107,
|
||||||
|
0.5591,
|
||||||
|
0.9130,
|
||||||
|
0.7572,
|
||||||
|
0.0302,
|
||||||
|
0.4532,
|
||||||
|
0.0893,
|
||||||
|
0.9490,
|
||||||
|
0.4880,
|
||||||
|
]
|
||||||
|
).to(torch_device)
|
||||||
|
|
||||||
|
# expected_early_stopping_layer = 2
|
||||||
|
# predicted_early_stopping_layer = torch.max(outputs.prune[1]).item()
|
||||||
|
# self.assertEqual(predicted_early_stopping_layer, expected_early_stopping_layer)
|
||||||
|
# self.assertEqual(predicted_number_of_matches, expected_second_number_of_matches)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Because of inconsistencies introduced between CUDA versions, the checks here are less strict. SuperGlue relies
|
||||||
|
on SuperPoint, which may, depending on CUDA version, return different number of keypoints (866 or 867 in this
|
||||||
|
specific test example). The consequence of having different number of keypoints is that the number of matches
|
||||||
|
will also be different. In the 20 first matches being checked, having one keypoint less will result in 1 less
|
||||||
|
match. The matching scores will also be different, as the keypoints are different. The checks here are less
|
||||||
|
strict to account for these inconsistencies.
|
||||||
|
Therefore, the test checks that the predicted number of matches, matches and matching scores are close to the
|
||||||
|
expected values, individually. Here, the tolerance of the number of values changing is set to 2.
|
||||||
|
|
||||||
|
This was discussed [here](https://github.com/huggingface/transformers/pull/29886#issuecomment-2482752787)
|
||||||
|
Such CUDA inconsistencies can be found
|
||||||
|
[here](https://github.com/huggingface/transformers/pull/33200/files#r1785980300)
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.assertTrue(abs(predicted_number_of_matches0 - expected_number_of_matches0) < 4)
|
||||||
|
self.assertTrue(abs(predicted_number_of_matches1 - expected_number_of_matches1) < 4)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.sum(~torch.isclose(predicted_matching_scores_values0, expected_matching_scores_values0, atol=1e-2))
|
||||||
|
< 4
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
torch.sum(~torch.isclose(predicted_matching_scores_values1, expected_matching_scores_values1, atol=1e-2))
|
||||||
|
< 4
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.sum(predicted_matches_values0 != expected_matches_values0) < 4)
|
||||||
|
self.assertTrue(torch.sum(predicted_matches_values1 != expected_matches_values1) < 4)
|
||||||
Reference in New Issue
Block a user