diff --git a/docs/source/en/model_doc/lightglue.md b/docs/source/en/model_doc/lightglue.md index 821d435cf4..0200285732 100644 --- a/docs/source/en/model_doc/lightglue.md +++ b/docs/source/en/model_doc/lightglue.md @@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size ```py # Easy visualization using the built-in plotting method - processor.plot_keypoint_matching(images, processed_outputs) + processor.visualize_keypoint_matching(images, processed_outputs) ```
@@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size - preprocess - post_process_keypoint_matching -- plot_keypoint_matching +- visualize_keypoint_matching diff --git a/docs/source/en/model_doc/superglue.md b/docs/source/en/model_doc/superglue.md index e4f1c8931e..acbf3561ca 100644 --- a/docs/source/en/model_doc/superglue.md +++ b/docs/source/en/model_doc/superglue.md @@ -103,38 +103,11 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}") ``` -- The example below demonstrates how to visualize matches between two images. +- Visualize the matches between the images using the built-in plotting functionality. ```py - import matplotlib.pyplot as plt - import numpy as np - - # Create side by side image - merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3)) - merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0 - merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0 - plt.imshow(merged_image) - plt.axis("off") - - # Retrieve the keypoints and matches - output = processed_outputs[0] - keypoints0 = output["keypoints0"] - keypoints1 = output["keypoints1"] - matching_scores = output["matching_scores"] - - # Plot the matches - for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores): - plt.plot( - [keypoint0[0], keypoint1[0] + image1.width], - [keypoint0[1], keypoint1[1]], - color=plt.get_cmap("RdYlGn")(matching_score.item()), - alpha=0.9, - linewidth=0.5, - ) - plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2) - plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2) - - plt.savefig("matched_image.png", dpi=300, bbox_inches='tight') + # Easy visualization using the built-in plotting method + processor.visualize_keypoint_matching(images, processed_outputs) ```
@@ -155,6 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size - preprocess - post_process_keypoint_matching +- visualize_keypoint_matching diff --git a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py index a1bed128da..5663e21fca 100644 --- a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py +++ b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py @@ -408,7 +408,7 @@ class EfficientLoFTRImageProcessor(BaseImageProcessor): images (`ImageInput`): Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2 images or a list of list of 2 images list with pixel values ranging from 0 to 255. - outputs (List[Dict[str, torch.Tensor]]]): + keypoint_matching_output (List[Dict[str, torch.Tensor]]]): A post processed keypoint matching output Returns: diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py index 124e4b04d5..c389929eea 100644 --- a/src/transformers/models/lightglue/image_processing_lightglue.py +++ b/src/transformers/models/lightglue/image_processing_lightglue.py @@ -17,6 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Optional, Union import numpy as np @@ -44,6 +45,9 @@ from ...utils.import_utils import requires from .modeling_lightglue import LightGlueKeypointMatchingOutput +if is_vision_available(): + from PIL import Image, ImageDraw + if is_vision_available(): import PIL @@ -402,18 +406,88 @@ class LightGlueImageProcessor(BaseImageProcessor): return results + def visualize_keypoint_matching( + self, + images: ImageInput, + keypoint_matching_output: list[dict[str, torch.Tensor]], + ) -> list["Image.Image"]: + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 + images or a list of list of 2 images list with pixel values ranging from 0 to 255. + keypoint_matching_output (List[Dict[str, torch.Tensor]]]): + A post processed keypoint matching output + + Returns: + `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected + keypoints as well as the matching between them. + """ + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + results = [] + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8) + plot_image[:height0, :width0] = image_pair[0] + plot_image[:height1, width0:] = image_pair[1] + + plot_image_pil = Image.fromarray(plot_image) + draw = ImageDraw.Draw(plot_image_pil) + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + color = self._get_color(matching_score) + draw.line( + (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y), + fill=color, + width=3, + ) + draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black") + draw.ellipse( + (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2), + fill="black", + ) + + results.append(plot_image_pil) + return results + + def _get_color(self, score): + """Maps a score to a color.""" + r = int(255 * (1 - score)) + g = int(255 * score) + b = 0 + return (r, g, b) + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): """ Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires matplotlib to be installed. + .. deprecated:: + `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead. + Args: images (`ImageInput`): Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or a list of list of 2 images list with pixel values ranging from 0 to 255. - outputs ([`LightGlueKeypointMatchingOutput`]): + keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]): Raw outputs of the model. """ + warnings.warn( + "`plot_keypoint_matching` is deprecated and will be removed in transformers v. " + "Use `visualize_keypoint_matching` instead.", + FutureWarning, + ) + if is_matplotlib_available(): import matplotlib.pyplot as plt else: diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py index cefb235fcb..2801727e43 100644 --- a/src/transformers/models/lightglue/modular_lightglue.py +++ b/src/transformers/models/lightglue/modular_lightglue.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Callable, Optional, Union @@ -20,7 +21,7 @@ from torch import nn from torch.nn.utils.rnn import pad_sequence from ...configuration_utils import PretrainedConfig -from ...image_utils import ImageInput, to_numpy_array +from ...image_utils import ImageInput, is_vision_available, to_numpy_array from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack @@ -35,6 +36,10 @@ from ..superglue.image_processing_superglue import SuperGlueImageProcessor, vali from ..superpoint import SuperPointConfig +if is_vision_available(): + from PIL import Image, ImageDraw + + logger = logging.get_logger(__name__) @@ -220,18 +225,90 @@ class LightGlueImageProcessor(SuperGlueImageProcessor): ) -> list[dict[str, torch.Tensor]]: return super().post_process_keypoint_matching(outputs, target_sizes, threshold) + # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue + def visualize_keypoint_matching( + self, + images: ImageInput, + keypoint_matching_output: list[dict[str, torch.Tensor]], + ) -> list["Image.Image"]: + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 + images or a list of list of 2 images list with pixel values ranging from 0 to 255. + keypoint_matching_output (List[Dict[str, torch.Tensor]]]): + A post processed keypoint matching output + + Returns: + `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected + keypoints as well as the matching between them. + """ + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + results = [] + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8) + plot_image[:height0, :width0] = image_pair[0] + plot_image[:height1, width0:] = image_pair[1] + + plot_image_pil = Image.fromarray(plot_image) + draw = ImageDraw.Draw(plot_image_pil) + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + color = self._get_color(matching_score) + draw.line( + (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y), + fill=color, + width=3, + ) + draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black") + draw.ellipse( + (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2), + fill="black", + ) + + results.append(plot_image_pil) + return results + + # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color + def _get_color(self, score): + """Maps a score to a color.""" + r = int(255 * (1 - score)) + g = int(255 * score) + b = 0 + return (r, g, b) + def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput): """ Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires matplotlib to be installed. + .. deprecated:: + `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead. + Args: images (`ImageInput`): Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or a list of list of 2 images list with pixel values ranging from 0 to 255. - outputs ([`LightGlueKeypointMatchingOutput`]): + keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]): Raw outputs of the model. """ + warnings.warn( + "`plot_keypoint_matching` is deprecated and will be removed in transformers v. " + "Use `visualize_keypoint_matching` instead.", + FutureWarning, + ) + if is_matplotlib_available(): import matplotlib.pyplot as plt else: diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py index 7fd31a905e..f02e2a9f65 100644 --- a/src/transformers/models/superglue/image_processing_superglue.py +++ b/src/transformers/models/superglue/image_processing_superglue.py @@ -47,6 +47,7 @@ if TYPE_CHECKING: if is_vision_available(): import PIL + from PIL import Image, ImageDraw logger = logging.get_logger(__name__) @@ -406,5 +407,68 @@ class SuperGlueImageProcessor(BaseImageProcessor): return results + # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->SuperGlue + def visualize_keypoint_matching( + self, + images: ImageInput, + keypoint_matching_output: list[dict[str, torch.Tensor]], + ) -> list["Image.Image"]: + """ + Plots the image pairs side by side with the detected keypoints as well as the matching between them. + + Args: + images (`ImageInput`): + Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2 + images or a list of list of 2 images list with pixel values ranging from 0 to 255. + keypoint_matching_output (List[Dict[str, torch.Tensor]]]): + A post processed keypoint matching output + + Returns: + `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected + keypoints as well as the matching between them. + """ + images = validate_and_format_image_pairs(images) + images = [to_numpy_array(image) for image in images] + image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)] + + results = [] + for image_pair, pair_output in zip(image_pairs, keypoint_matching_output): + height0, width0 = image_pair[0].shape[:2] + height1, width1 = image_pair[1].shape[:2] + plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8) + plot_image[:height0, :width0] = image_pair[0] + plot_image[:height1, width0:] = image_pair[1] + + plot_image_pil = Image.fromarray(plot_image) + draw = ImageDraw.Draw(plot_image_pil) + + keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1) + keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1) + for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip( + keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"] + ): + color = self._get_color(matching_score) + draw.line( + (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y), + fill=color, + width=3, + ) + draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black") + draw.ellipse( + (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2), + fill="black", + ) + + results.append(plot_image_pil) + return results + + # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color + def _get_color(self, score): + """Maps a score to a color.""" + r = int(255 * (1 - score)) + g = int(255 * score) + b = 0 + return (r, g, b) + __all__ = ["SuperGlueImageProcessor"]