[image-processing] deprecate plot_keypoint_matching, make visualize_keypoint_matching the standard (#39830)

* fix: deprecate plot_keypoint_matching and add visualize_keypoint_matching to all keypoint matching models

* refactor: added copied from

* fix: make style

* fix: repo consistency

* fix: make style

* docs: added missing method in SuperGlue docs
This commit is contained in:
StevenBucaille
2025-08-01 12:29:57 -04:00
committed by GitHub
parent 7b4d9843ba
commit 1ec0feccdd
6 changed files with 225 additions and 36 deletions

View File

@@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
```py
# Easy visualization using the built-in plotting method
processor.plot_keypoint_matching(images, processed_outputs)
processor.visualize_keypoint_matching(images, processed_outputs)
```
<div class="flex justify-center">
@@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- preprocess
- post_process_keypoint_matching
- plot_keypoint_matching
- visualize_keypoint_matching
<frameworkcontent>
<pt>

View File

@@ -103,38 +103,11 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}")
```
- The example below demonstrates how to visualize matches between two images.
- Visualize the matches between the images using the built-in plotting functionality.
```py
import matplotlib.pyplot as plt
import numpy as np
# Create side by side image
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
plt.imshow(merged_image)
plt.axis("off")
# Retrieve the keypoints and matches
output = processed_outputs[0]
keypoints0 = output["keypoints0"]
keypoints1 = output["keypoints1"]
matching_scores = output["matching_scores"]
# Plot the matches
for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores):
plt.plot(
[keypoint0[0], keypoint1[0] + image1.width],
[keypoint0[1], keypoint1[1]],
color=plt.get_cmap("RdYlGn")(matching_score.item()),
alpha=0.9,
linewidth=0.5,
)
plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2)
plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2)
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
# Easy visualization using the built-in plotting method
processor.visualize_keypoint_matching(images, processed_outputs)
```
<div class="flex justify-center">
@@ -155,6 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- preprocess
- post_process_keypoint_matching
- visualize_keypoint_matching
<frameworkcontent>
<pt>

View File

@@ -408,7 +408,7 @@ class EfficientLoFTRImageProcessor(BaseImageProcessor):
images (`ImageInput`):
Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
outputs (List[Dict[str, torch.Tensor]]]):
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
A post processed keypoint matching output
Returns:

View File

@@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Optional, Union
import numpy as np
@@ -44,6 +45,9 @@ from ...utils.import_utils import requires
from .modeling_lightglue import LightGlueKeypointMatchingOutput
if is_vision_available():
from PIL import Image, ImageDraw
if is_vision_available():
import PIL
@@ -402,18 +406,88 @@ class LightGlueImageProcessor(BaseImageProcessor):
return results
def visualize_keypoint_matching(
    self,
    images: ImageInput,
    keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
    """
    Draws every image pair side by side and overlays the matched keypoints and the lines joining them.

    Args:
        images (`ImageInput`):
            Image pairs to draw. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
            images or a list of lists of 2 images, with pixel values ranging from 0 to 255.
        keypoint_matching_output (`list[dict[str, torch.Tensor]]`):
            A post processed keypoint matching output. Each dict is read through its `"keypoints0"`,
            `"keypoints1"` and `"matching_scores"` entries.

    Returns:
        `list[PIL.Image.Image]`: One PIL image per pair, showing both images side by side together with the
        matched keypoints and the lines connecting them.
    """
    arrays = [to_numpy_array(image) for image in validate_and_format_image_pairs(images)]

    rendered = []
    # Consecutive entries of the flattened list (0-1, 2-3, ...) form one pair each.
    for start, pair_output in zip(range(0, len(arrays), 2), keypoint_matching_output):
        left = arrays[start]
        right = arrays[start + 1]
        left_height, left_width = left.shape[:2]
        right_height, right_width = right.shape[:2]

        # Black canvas tall enough for the taller image and wide enough for both; images are top-aligned.
        canvas = np.zeros((max(left_height, right_height), left_width + right_width, 3), dtype=np.uint8)
        canvas[:left_height, :left_width] = left
        canvas[:right_height, left_width:] = right

        canvas_pil = Image.fromarray(canvas)
        draw = ImageDraw.Draw(canvas_pil)

        xs0, ys0 = pair_output["keypoints0"].unbind(1)
        xs1, ys1 = pair_output["keypoints1"].unbind(1)
        scores = pair_output["matching_scores"]
        for x0, y0, x1, y1, score in zip(xs0, ys0, xs1, ys1, scores):
            # Keypoints of the second image live in its own frame: shift them past the first image.
            shifted_x1 = x1 + left_width
            draw.line((x0, y0, shifted_x1, y1), fill=self._get_color(score), width=3)
            draw.ellipse((x0 - 2, y0 - 2, x0 + 2, y0 + 2), fill="black")
            draw.ellipse((shifted_x1 - 2, y1 - 2, shifted_x1 + 2, y1 + 2), fill="black")

        rendered.append(canvas_pil)
    return rendered
def _get_color(self, score):
    """
    Converts a matching score into an `(r, g, b)` tuple of ints.

    Red is `255 * (1 - score)` and green is `255 * score`, each truncated by `int()`; blue is always 0. A score
    of 0 therefore gives pure red and a score of 1 pure green.
    """
    red = int(255 * (1 - score))
    green = int(255 * score)
    return (red, green, 0)
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.
.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255.
outputs ([`LightGlueKeypointMatchingOutput`]):
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)
if is_matplotlib_available():
import matplotlib.pyplot as plt
else:

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Callable, Optional, Union
@@ -20,7 +21,7 @@ from torch import nn
from torch.nn.utils.rnn import pad_sequence
from ...configuration_utils import PretrainedConfig
from ...image_utils import ImageInput, to_numpy_array
from ...image_utils import ImageInput, is_vision_available, to_numpy_array
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -35,6 +36,10 @@ from ..superglue.image_processing_superglue import SuperGlueImageProcessor, vali
from ..superpoint import SuperPointConfig
if is_vision_available():
from PIL import Image, ImageDraw
logger = logging.get_logger(__name__)
@@ -220,18 +225,90 @@ class LightGlueImageProcessor(SuperGlueImageProcessor):
) -> list[dict[str, torch.Tensor]]:
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
def visualize_keypoint_matching(
    self,
    images: ImageInput,
    keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
    """
    Draws every image pair side by side and overlays the matched keypoints and the lines joining them.

    Args:
        images (`ImageInput`):
            Image pairs to draw. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
            images or a list of lists of 2 images, with pixel values ranging from 0 to 255.
        keypoint_matching_output (`list[dict[str, torch.Tensor]]`):
            A post processed keypoint matching output. Each dict is read through its `"keypoints0"`,
            `"keypoints1"` and `"matching_scores"` entries.

    Returns:
        `list[PIL.Image.Image]`: One PIL image per pair, showing both images side by side together with the
        matched keypoints and the lines connecting them.
    """
    arrays = [to_numpy_array(image) for image in validate_and_format_image_pairs(images)]

    rendered = []
    # Consecutive entries of the flattened list (0-1, 2-3, ...) form one pair each.
    for start, pair_output in zip(range(0, len(arrays), 2), keypoint_matching_output):
        left = arrays[start]
        right = arrays[start + 1]
        left_height, left_width = left.shape[:2]
        right_height, right_width = right.shape[:2]

        # Black canvas tall enough for the taller image and wide enough for both; images are top-aligned.
        canvas = np.zeros((max(left_height, right_height), left_width + right_width, 3), dtype=np.uint8)
        canvas[:left_height, :left_width] = left
        canvas[:right_height, left_width:] = right

        canvas_pil = Image.fromarray(canvas)
        draw = ImageDraw.Draw(canvas_pil)

        xs0, ys0 = pair_output["keypoints0"].unbind(1)
        xs1, ys1 = pair_output["keypoints1"].unbind(1)
        scores = pair_output["matching_scores"]
        for x0, y0, x1, y1, score in zip(xs0, ys0, xs1, ys1, scores):
            # Keypoints of the second image live in its own frame: shift them past the first image.
            shifted_x1 = x1 + left_width
            draw.line((x0, y0, shifted_x1, y1), fill=self._get_color(score), width=3)
            draw.ellipse((x0 - 2, y0 - 2, x0 + 2, y0 + 2), fill="black")
            draw.ellipse((shifted_x1 - 2, y1 - 2, shifted_x1 + 2, y1 + 2), fill="black")

        rendered.append(canvas_pil)
    return rendered
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
def _get_color(self, score):
    """
    Converts a matching score into an `(r, g, b)` tuple of ints.

    Red is `255 * (1 - score)` and green is `255 * score`, each truncated by `int()`; blue is always 0. A score
    of 0 therefore gives pure red and a score of 1 pure green.
    """
    red = int(255 * (1 - score))
    green = int(255 * score)
    return (red, green, 0)
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.
.. deprecated::
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255.
outputs ([`LightGlueKeypointMatchingOutput`]):
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
warnings.warn(
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
"Use `visualize_keypoint_matching` instead.",
FutureWarning,
)
if is_matplotlib_available():
import matplotlib.pyplot as plt
else:

View File

@@ -47,6 +47,7 @@ if TYPE_CHECKING:
if is_vision_available():
import PIL
from PIL import Image, ImageDraw
logger = logging.get_logger(__name__)
@@ -406,5 +407,68 @@ class SuperGlueImageProcessor(BaseImageProcessor):
return results
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->SuperGlue
def visualize_keypoint_matching(
    self,
    images: ImageInput,
    keypoint_matching_output: list[dict[str, torch.Tensor]],
) -> list["Image.Image"]:
    """
    Draws every image pair side by side and overlays the matched keypoints and the lines joining them.

    Args:
        images (`ImageInput`):
            Image pairs to draw. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
            images or a list of lists of 2 images, with pixel values ranging from 0 to 255.
        keypoint_matching_output (`list[dict[str, torch.Tensor]]`):
            A post processed keypoint matching output. Each dict is read through its `"keypoints0"`,
            `"keypoints1"` and `"matching_scores"` entries.

    Returns:
        `list[PIL.Image.Image]`: One PIL image per pair, showing both images side by side together with the
        matched keypoints and the lines connecting them.
    """
    arrays = [to_numpy_array(image) for image in validate_and_format_image_pairs(images)]

    rendered = []
    # Consecutive entries of the flattened list (0-1, 2-3, ...) form one pair each.
    for start, pair_output in zip(range(0, len(arrays), 2), keypoint_matching_output):
        left = arrays[start]
        right = arrays[start + 1]
        left_height, left_width = left.shape[:2]
        right_height, right_width = right.shape[:2]

        # Black canvas tall enough for the taller image and wide enough for both; images are top-aligned.
        canvas = np.zeros((max(left_height, right_height), left_width + right_width, 3), dtype=np.uint8)
        canvas[:left_height, :left_width] = left
        canvas[:right_height, left_width:] = right

        canvas_pil = Image.fromarray(canvas)
        draw = ImageDraw.Draw(canvas_pil)

        xs0, ys0 = pair_output["keypoints0"].unbind(1)
        xs1, ys1 = pair_output["keypoints1"].unbind(1)
        scores = pair_output["matching_scores"]
        for x0, y0, x1, y1, score in zip(xs0, ys0, xs1, ys1, scores):
            # Keypoints of the second image live in its own frame: shift them past the first image.
            shifted_x1 = x1 + left_width
            draw.line((x0, y0, shifted_x1, y1), fill=self._get_color(score), width=3)
            draw.ellipse((x0 - 2, y0 - 2, x0 + 2, y0 + 2), fill="black")
            draw.ellipse((shifted_x1 - 2, y1 - 2, shifted_x1 + 2, y1 + 2), fill="black")

        rendered.append(canvas_pil)
    return rendered
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
def _get_color(self, score):
    """
    Converts a matching score into an `(r, g, b)` tuple of ints.

    Red is `255 * (1 - score)` and green is `255 * score`, each truncated by `int()`; blue is always 0. A score
    of 0 therefore gives pure red and a score of 1 pure green.
    """
    red = int(255 * (1 - score))
    green = int(255 * score)
    return (red, green, 0)
__all__ = ["SuperGlueImageProcessor"]