[image-processing] deprecate plot_keypoint_matching, make visualize_keypoint_matching as a standard (#39830)

* fix: deprecate plot_keypoint_matching and make visualize_keypoint_matching for all Keypoint Matching models

* refactor: added copied from

* fix: make style

* fix: repo consistency

* fix: make style

* docs: added missing method in SuperGlue docs
This commit is contained in:
StevenBucaille
2025-08-01 12:29:57 -04:00
committed by GitHub
parent 7b4d9843ba
commit 1ec0feccdd
6 changed files with 225 additions and 36 deletions

View File

@@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
```py ```py
# Easy visualization using the built-in plotting method # Easy visualization using the built-in plotting method
processor.plot_keypoint_matching(images, processed_outputs) processor.visualize_keypoint_matching(images, processed_outputs)
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
@@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- preprocess - preprocess
- post_process_keypoint_matching - post_process_keypoint_matching
- plot_keypoint_matching - visualize_keypoint_matching
<frameworkcontent> <frameworkcontent>
<pt> <pt>

View File

@@ -103,38 +103,11 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}") print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}")
``` ```
- The example below demonstrates how to visualize matches between two images. - Visualize the matches between the images using the built-in plotting functionality.
```py ```py
import matplotlib.pyplot as plt # Easy visualization using the built-in plotting method
import numpy as np processor.visualize_keypoint_matching(images, processed_outputs)
# Create side by side image
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
plt.imshow(merged_image)
plt.axis("off")
# Retrieve the keypoints and matches
output = processed_outputs[0]
keypoints0 = output["keypoints0"]
keypoints1 = output["keypoints1"]
matching_scores = output["matching_scores"]
# Plot the matches
for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores):
plt.plot(
[keypoint0[0], keypoint1[0] + image1.width],
[keypoint0[1], keypoint1[1]],
color=plt.get_cmap("RdYlGn")(matching_score.item()),
alpha=0.9,
linewidth=0.5,
)
plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2)
plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2)
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
``` ```
<div class="flex justify-center"> <div class="flex justify-center">
@@ -155,6 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- preprocess - preprocess
- post_process_keypoint_matching - post_process_keypoint_matching
- visualize_keypoint_matching
<frameworkcontent> <frameworkcontent>
<pt> <pt>

View File

@@ -408,7 +408,7 @@ class EfficientLoFTRImageProcessor(BaseImageProcessor):
images (`ImageInput`): images (`ImageInput`):
Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2 Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
images or a list of list of 2 images list with pixel values ranging from 0 to 255. images or a list of list of 2 images list with pixel values ranging from 0 to 255.
outputs (List[Dict[str, torch.Tensor]]]): keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
A post processed keypoint matching output A post processed keypoint matching output
Returns: Returns:

View File

@@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from typing import Optional, Union from typing import Optional, Union
import numpy as np import numpy as np
@@ -44,6 +45,9 @@ from ...utils.import_utils import requires
from .modeling_lightglue import LightGlueKeypointMatchingOutput from .modeling_lightglue import LightGlueKeypointMatchingOutput
if is_vision_available():
from PIL import Image, ImageDraw
if is_vision_available(): if is_vision_available():
import PIL import PIL
@@ -402,18 +406,88 @@ class LightGlueImageProcessor(BaseImageProcessor):
return results return results
def visualize_keypoint_matching(
self,
images: ImageInput,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
A post processed keypoint matching output
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
plot_image[:height0, :width0] = image_pair[0]
plot_image[:height1, width0:] = image_pair[1]
plot_image_pil = Image.fromarray(plot_image)
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return (r, g, b)
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
""" """
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed. matplotlib to be installed.
.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
Args: Args:
images (`ImageInput`): images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255. a list of list of 2 images list with pixel values ranging from 0 to 255.
outputs ([`LightGlueKeypointMatchingOutput`]): keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model. Raw outputs of the model.
""" """
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
else: else:

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
@@ -20,7 +21,7 @@ from torch import nn
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...image_utils import ImageInput, to_numpy_array from ...image_utils import ImageInput, is_vision_available, to_numpy_array
from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack from ...processing_utils import Unpack
@@ -35,6 +36,10 @@ from ..superglue.image_processing_superglue import SuperGlueImageProcessor, vali
from ..superpoint import SuperPointConfig from ..superpoint import SuperPointConfig
if is_vision_available():
from PIL import Image, ImageDraw
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -220,18 +225,90 @@ class LightGlueImageProcessor(SuperGlueImageProcessor):
) -> list[dict[str, torch.Tensor]]: ) -> list[dict[str, torch.Tensor]]:
return super().post_process_keypoint_matching(outputs, target_sizes, threshold) return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
def visualize_keypoint_matching(
self,
images: ImageInput,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
A post processed keypoint matching output
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
plot_image[:height0, :width0] = image_pair[0]
plot_image[:height1, width0:] = image_pair[1]
plot_image_pil = Image.fromarray(plot_image)
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return (r, g, b)
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
""" """
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed. matplotlib to be installed.
.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
Args: Args:
images (`ImageInput`): images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255. a list of list of 2 images list with pixel values ranging from 0 to 255.
outputs ([`LightGlueKeypointMatchingOutput`]): keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model. Raw outputs of the model.
""" """
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
else: else:

View File

@@ -47,6 +47,7 @@ if TYPE_CHECKING:
if is_vision_available(): if is_vision_available():
import PIL import PIL
from PIL import Image, ImageDraw
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@@ -406,5 +407,68 @@ class SuperGlueImageProcessor(BaseImageProcessor):
return results return results
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->SuperGlue
def visualize_keypoint_matching(
self,
images: ImageInput,
keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
Args:
images (`ImageInput`):
Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
A post processed keypoint matching output
Returns:
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
keypoints as well as the matching between them.
"""
images = validate_and_format_image_pairs(images)
images = [to_numpy_array(image) for image in images]
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
results = []
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
height0, width0 = image_pair[0].shape[:2]
height1, width1 = image_pair[1].shape[:2]
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
plot_image[:height0, :width0] = image_pair[0]
plot_image[:height1, width0:] = image_pair[1]
plot_image_pil = Image.fromarray(plot_image)
draw = ImageDraw.Draw(plot_image_pil)
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
):
color = self._get_color(matching_score)
draw.line(
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
fill=color,
width=3,
)
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
draw.ellipse(
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
fill="black",
)
results.append(plot_image_pil)
return results
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
def _get_color(self, score):
"""Maps a score to a color."""
r = int(255 * (1 - score))
g = int(255 * score)
b = 0
return (r, g, b)
__all__ = ["SuperGlueImageProcessor"] __all__ = ["SuperGlueImageProcessor"]