diff --git a/docs/source/en/model_doc/lightglue.md b/docs/source/en/model_doc/lightglue.md
index 821d435cf4..0200285732 100644
--- a/docs/source/en/model_doc/lightglue.md
+++ b/docs/source/en/model_doc/lightglue.md
@@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
```py
# Easy visualization using the built-in plotting method
- processor.plot_keypoint_matching(images, processed_outputs)
+ processor.visualize_keypoint_matching(images, processed_outputs)
```
@@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- preprocess
- post_process_keypoint_matching
-- plot_keypoint_matching
+- visualize_keypoint_matching
diff --git a/docs/source/en/model_doc/superglue.md b/docs/source/en/model_doc/superglue.md
index e4f1c8931e..acbf3561ca 100644
--- a/docs/source/en/model_doc/superglue.md
+++ b/docs/source/en/model_doc/superglue.md
@@ -103,38 +103,11 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}")
```
-- The example below demonstrates how to visualize matches between two images.
+- Visualize the matches between the images using the built-in plotting functionality.
```py
- import matplotlib.pyplot as plt
- import numpy as np
-
- # Create side by side image
- merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
- merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
- merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
- plt.imshow(merged_image)
- plt.axis("off")
-
- # Retrieve the keypoints and matches
- output = processed_outputs[0]
- keypoints0 = output["keypoints0"]
- keypoints1 = output["keypoints1"]
- matching_scores = output["matching_scores"]
-
- # Plot the matches
- for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores):
- plt.plot(
- [keypoint0[0], keypoint1[0] + image1.width],
- [keypoint0[1], keypoint1[1]],
- color=plt.get_cmap("RdYlGn")(matching_score.item()),
- alpha=0.9,
- linewidth=0.5,
- )
- plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2)
- plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2)
-
- plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
+ # Easy visualization using the built-in plotting method
+ processor.visualize_keypoint_matching(images, processed_outputs)
```
@@ -155,6 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- preprocess
- post_process_keypoint_matching
+- visualize_keypoint_matching
diff --git a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py
index a1bed128da..5663e21fca 100644
--- a/src/transformers/models/efficientloftr/image_processing_efficientloftr.py
+++ b/src/transformers/models/efficientloftr/image_processing_efficientloftr.py
@@ -408,7 +408,7 @@ class EfficientLoFTRImageProcessor(BaseImageProcessor):
images (`ImageInput`):
Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
- outputs (List[Dict[str, torch.Tensor]]]):
+ keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
A post processed keypoint matching output
Returns:
diff --git a/src/transformers/models/lightglue/image_processing_lightglue.py b/src/transformers/models/lightglue/image_processing_lightglue.py
index 124e4b04d5..c389929eea 100644
--- a/src/transformers/models/lightglue/image_processing_lightglue.py
+++ b/src/transformers/models/lightglue/image_processing_lightglue.py
@@ -17,6 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
from typing import Optional, Union
import numpy as np
@@ -44,6 +45,9 @@ from ...utils.import_utils import requires
from .modeling_lightglue import LightGlueKeypointMatchingOutput
+if is_vision_available():
+ from PIL import Image, ImageDraw
+
if is_vision_available():
import PIL
@@ -402,18 +406,88 @@ class LightGlueImageProcessor(BaseImageProcessor):
return results
+ def visualize_keypoint_matching(
+ self,
+ images: ImageInput,
+ keypoint_matching_output: list[dict[str, torch.Tensor]],
+ ) -> list["Image.Image"]:
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
+ images or a list of list of 2 images list with pixel values ranging from 0 to 255.
+ keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
+ A post processed keypoint matching output
+
+ Returns:
+ `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
+ keypoints as well as the matching between them.
+ """
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ results = []
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
+ plot_image[:height0, :width0] = image_pair[0]
+ plot_image[:height1, width0:] = image_pair[1]
+
+ plot_image_pil = Image.fromarray(plot_image)
+ draw = ImageDraw.Draw(plot_image_pil)
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ color = self._get_color(matching_score)
+ draw.line(
+ (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
+ fill=color,
+ width=3,
+ )
+ draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
+ draw.ellipse(
+ (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
+ fill="black",
+ )
+
+ results.append(plot_image_pil)
+ return results
+
+ def _get_color(self, score):
+ """Maps a score to a color."""
+ r = int(255 * (1 - score))
+ g = int(255 * score)
+ b = 0
+ return (r, g, b)
+
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.
+ .. deprecated::
+ `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
+
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255.
- outputs ([`LightGlueKeypointMatchingOutput`]):
+ keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
+ warnings.warn(
+ "`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
+ "Use `visualize_keypoint_matching` instead.",
+ FutureWarning,
+ )
+
if is_matplotlib_available():
import matplotlib.pyplot as plt
else:
diff --git a/src/transformers/models/lightglue/modular_lightglue.py b/src/transformers/models/lightglue/modular_lightglue.py
index cefb235fcb..2801727e43 100644
--- a/src/transformers/models/lightglue/modular_lightglue.py
+++ b/src/transformers/models/lightglue/modular_lightglue.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import warnings
from dataclasses import dataclass
from typing import Callable, Optional, Union
@@ -20,7 +21,7 @@ from torch import nn
from torch.nn.utils.rnn import pad_sequence
from ...configuration_utils import PretrainedConfig
-from ...image_utils import ImageInput, to_numpy_array
+from ...image_utils import ImageInput, is_vision_available, to_numpy_array
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
@@ -35,6 +36,10 @@ from ..superglue.image_processing_superglue import SuperGlueImageProcessor, vali
from ..superpoint import SuperPointConfig
+if is_vision_available():
+ from PIL import Image, ImageDraw
+
+
logger = logging.get_logger(__name__)
@@ -220,18 +225,90 @@ class LightGlueImageProcessor(SuperGlueImageProcessor):
) -> list[dict[str, torch.Tensor]]:
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
+ def visualize_keypoint_matching(
+ self,
+ images: ImageInput,
+ keypoint_matching_output: list[dict[str, torch.Tensor]],
+ ) -> list["Image.Image"]:
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
+ images or a list of list of 2 images list with pixel values ranging from 0 to 255.
+ keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
+ A post processed keypoint matching output
+
+ Returns:
+ `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
+ keypoints as well as the matching between them.
+ """
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ results = []
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
+ plot_image[:height0, :width0] = image_pair[0]
+ plot_image[:height1, width0:] = image_pair[1]
+
+ plot_image_pil = Image.fromarray(plot_image)
+ draw = ImageDraw.Draw(plot_image_pil)
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ color = self._get_color(matching_score)
+ draw.line(
+ (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
+ fill=color,
+ width=3,
+ )
+ draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
+ draw.ellipse(
+ (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
+ fill="black",
+ )
+
+ results.append(plot_image_pil)
+ return results
+
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
+ def _get_color(self, score):
+ """Maps a score to a color."""
+ r = int(255 * (1 - score))
+ g = int(255 * score)
+ b = 0
+ return (r, g, b)
+
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
"""
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
matplotlib to be installed.
+ .. deprecated::
+ `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
+
Args:
images (`ImageInput`):
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
a list of list of 2 images list with pixel values ranging from 0 to 255.
- outputs ([`LightGlueKeypointMatchingOutput`]):
+ keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
Raw outputs of the model.
"""
+ warnings.warn(
+ "`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
+ "Use `visualize_keypoint_matching` instead.",
+ FutureWarning,
+ )
+
if is_matplotlib_available():
import matplotlib.pyplot as plt
else:
diff --git a/src/transformers/models/superglue/image_processing_superglue.py b/src/transformers/models/superglue/image_processing_superglue.py
index 7fd31a905e..f02e2a9f65 100644
--- a/src/transformers/models/superglue/image_processing_superglue.py
+++ b/src/transformers/models/superglue/image_processing_superglue.py
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
if is_vision_available():
import PIL
+ from PIL import Image, ImageDraw
logger = logging.get_logger(__name__)
@@ -406,5 +407,68 @@ class SuperGlueImageProcessor(BaseImageProcessor):
return results
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->SuperGlue
+ def visualize_keypoint_matching(
+ self,
+ images: ImageInput,
+ keypoint_matching_output: list[dict[str, torch.Tensor]],
+ ) -> list["Image.Image"]:
+ """
+ Plots the image pairs side by side with the detected keypoints as well as the matching between them.
+
+ Args:
+ images (`ImageInput`):
+ Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
+ images or a list of list of 2 images list with pixel values ranging from 0 to 255.
+ keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
+ A post processed keypoint matching output
+
+ Returns:
+ `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
+ keypoints as well as the matching between them.
+ """
+ images = validate_and_format_image_pairs(images)
+ images = [to_numpy_array(image) for image in images]
+ image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
+
+ results = []
+ for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
+ height0, width0 = image_pair[0].shape[:2]
+ height1, width1 = image_pair[1].shape[:2]
+ plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
+ plot_image[:height0, :width0] = image_pair[0]
+ plot_image[:height1, width0:] = image_pair[1]
+
+ plot_image_pil = Image.fromarray(plot_image)
+ draw = ImageDraw.Draw(plot_image_pil)
+
+ keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
+ keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
+ keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
+ ):
+ color = self._get_color(matching_score)
+ draw.line(
+ (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
+ fill=color,
+ width=3,
+ )
+ draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
+ draw.ellipse(
+ (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
+ fill="black",
+ )
+
+ results.append(plot_image_pil)
+ return results
+
+ # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
+ def _get_color(self, score):
+ """Maps a score to a color."""
+ r = int(255 * (1 - score))
+ g = int(255 * score)
+ b = 0
+ return (r, g, b)
+
__all__ = ["SuperGlueImageProcessor"]