[image-processing] deprecate plot_keypoint_matching, make visualize_keypoint_matching as a standard (#39830)
* fix: deprecate plot_keypoint_matching and make visualize_keypoint_matching for all Keypoint Matching models * refactor: added copied from * fix: make style * fix: repo consistency * fix: make style * docs: added missing method in SuperGlue docs
This commit is contained in:
@@ -107,7 +107,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
|
||||
```py
|
||||
# Easy visualization using the built-in plotting method
|
||||
processor.plot_keypoint_matching(images, processed_outputs)
|
||||
processor.visualize_keypoint_matching(images, processed_outputs)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
@@ -128,7 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
|
||||
- preprocess
|
||||
- post_process_keypoint_matching
|
||||
- plot_keypoint_matching
|
||||
- visualize_keypoint_matching
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
@@ -103,38 +103,11 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
print(f"Keypoint at {keypoint0.numpy()} matches with keypoint at {keypoint1.numpy()} with score {matching_score}")
|
||||
```
|
||||
|
||||
- The example below demonstrates how to visualize matches between two images.
|
||||
- Visualize the matches between the images using the built-in plotting functionality.
|
||||
|
||||
```py
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
# Create side by side image
|
||||
merged_image = np.zeros((max(image1.height, image2.height), image1.width + image2.width, 3))
|
||||
merged_image[: image1.height, : image1.width] = np.array(image1) / 255.0
|
||||
merged_image[: image2.height, image1.width :] = np.array(image2) / 255.0
|
||||
plt.imshow(merged_image)
|
||||
plt.axis("off")
|
||||
|
||||
# Retrieve the keypoints and matches
|
||||
output = processed_outputs[0]
|
||||
keypoints0 = output["keypoints0"]
|
||||
keypoints1 = output["keypoints1"]
|
||||
matching_scores = output["matching_scores"]
|
||||
|
||||
# Plot the matches
|
||||
for keypoint0, keypoint1, matching_score in zip(keypoints0, keypoints1, matching_scores):
|
||||
plt.plot(
|
||||
[keypoint0[0], keypoint1[0] + image1.width],
|
||||
[keypoint0[1], keypoint1[1]],
|
||||
color=plt.get_cmap("RdYlGn")(matching_score.item()),
|
||||
alpha=0.9,
|
||||
linewidth=0.5,
|
||||
)
|
||||
plt.scatter(keypoint0[0], keypoint0[1], c="black", s=2)
|
||||
plt.scatter(keypoint1[0] + image1.width, keypoint1[1], c="black", s=2)
|
||||
|
||||
plt.savefig("matched_image.png", dpi=300, bbox_inches='tight')
|
||||
# Easy visualization using the built-in plotting method
|
||||
processor.visualize_keypoint_matching(images, processed_outputs)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
@@ -155,6 +128,7 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
|
||||
|
||||
- preprocess
|
||||
- post_process_keypoint_matching
|
||||
- visualize_keypoint_matching
|
||||
|
||||
<frameworkcontent>
|
||||
<pt>
|
||||
|
||||
@@ -408,7 +408,7 @@ class EfficientLoFTRImageProcessor(BaseImageProcessor):
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `EfficientLoFTRImageProcessor.preprocess`. Expects either a list of 2
|
||||
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
outputs (List[Dict[str, torch.Tensor]]]):
|
||||
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
|
||||
A post processed keypoint matching output
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -44,6 +45,9 @@ from ...utils.import_utils import requires
|
||||
from .modeling_lightglue import LightGlueKeypointMatchingOutput
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
|
||||
@@ -402,18 +406,88 @@ class LightGlueImageProcessor(BaseImageProcessor):
|
||||
|
||||
return results
|
||||
|
||||
def visualize_keypoint_matching(
|
||||
self,
|
||||
images: ImageInput,
|
||||
keypoint_matching_output: list[dict[str, torch.Tensor]],
|
||||
) -> list["Image.Image"]:
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
|
||||
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
|
||||
A post processed keypoint matching output
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
|
||||
keypoints as well as the matching between them.
|
||||
"""
|
||||
images = validate_and_format_image_pairs(images)
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||
|
||||
results = []
|
||||
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||
height0, width0 = image_pair[0].shape[:2]
|
||||
height1, width1 = image_pair[1].shape[:2]
|
||||
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
|
||||
plot_image[:height0, :width0] = image_pair[0]
|
||||
plot_image[:height1, width0:] = image_pair[1]
|
||||
|
||||
plot_image_pil = Image.fromarray(plot_image)
|
||||
draw = ImageDraw.Draw(plot_image_pil)
|
||||
|
||||
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||
):
|
||||
color = self._get_color(matching_score)
|
||||
draw.line(
|
||||
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
|
||||
fill=color,
|
||||
width=3,
|
||||
)
|
||||
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
|
||||
draw.ellipse(
|
||||
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
results.append(plot_image_pil)
|
||||
return results
|
||||
|
||||
def _get_color(self, score):
|
||||
"""Maps a score to a color."""
|
||||
r = int(255 * (1 - score))
|
||||
g = int(255 * score)
|
||||
b = 0
|
||||
return (r, g, b)
|
||||
|
||||
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
|
||||
matplotlib to be installed.
|
||||
|
||||
.. deprecated::
|
||||
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
|
||||
a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
outputs ([`LightGlueKeypointMatchingOutput`]):
|
||||
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
|
||||
Raw outputs of the model.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
|
||||
"Use `visualize_keypoint_matching` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
else:
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@@ -20,7 +21,7 @@ from torch import nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
from ...image_utils import ImageInput, to_numpy_array
|
||||
from ...image_utils import ImageInput, is_vision_available, to_numpy_array
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
@@ -35,6 +36,10 @@ from ..superglue.image_processing_superglue import SuperGlueImageProcessor, vali
|
||||
from ..superpoint import SuperPointConfig
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@@ -220,18 +225,90 @@ class LightGlueImageProcessor(SuperGlueImageProcessor):
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
|
||||
|
||||
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
|
||||
def visualize_keypoint_matching(
|
||||
self,
|
||||
images: ImageInput,
|
||||
keypoint_matching_output: list[dict[str, torch.Tensor]],
|
||||
) -> list["Image.Image"]:
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
|
||||
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
|
||||
A post processed keypoint matching output
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
|
||||
keypoints as well as the matching between them.
|
||||
"""
|
||||
images = validate_and_format_image_pairs(images)
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||
|
||||
results = []
|
||||
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||
height0, width0 = image_pair[0].shape[:2]
|
||||
height1, width1 = image_pair[1].shape[:2]
|
||||
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
|
||||
plot_image[:height0, :width0] = image_pair[0]
|
||||
plot_image[:height1, width0:] = image_pair[1]
|
||||
|
||||
plot_image_pil = Image.fromarray(plot_image)
|
||||
draw = ImageDraw.Draw(plot_image_pil)
|
||||
|
||||
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||
):
|
||||
color = self._get_color(matching_score)
|
||||
draw.line(
|
||||
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
|
||||
fill=color,
|
||||
width=3,
|
||||
)
|
||||
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
|
||||
draw.ellipse(
|
||||
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
results.append(plot_image_pil)
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
|
||||
def _get_color(self, score):
|
||||
"""Maps a score to a color."""
|
||||
r = int(255 * (1 - score))
|
||||
g = int(255 * score)
|
||||
b = 0
|
||||
return (r, g, b)
|
||||
|
||||
def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
|
||||
matplotlib to be installed.
|
||||
|
||||
.. deprecated::
|
||||
`plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
|
||||
a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
outputs ([`LightGlueKeypointMatchingOutput`]):
|
||||
keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
|
||||
Raw outputs of the model.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
|
||||
"Use `visualize_keypoint_matching` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if is_matplotlib_available():
|
||||
import matplotlib.pyplot as plt
|
||||
else:
|
||||
|
||||
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
|
||||
|
||||
if is_vision_available():
|
||||
import PIL
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -406,5 +407,68 @@ class SuperGlueImageProcessor(BaseImageProcessor):
|
||||
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->SuperGlue
|
||||
def visualize_keypoint_matching(
|
||||
self,
|
||||
images: ImageInput,
|
||||
keypoint_matching_output: list[dict[str, torch.Tensor]],
|
||||
) -> list["Image.Image"]:
|
||||
"""
|
||||
Plots the image pairs side by side with the detected keypoints as well as the matching between them.
|
||||
|
||||
Args:
|
||||
images (`ImageInput`):
|
||||
Image pairs to plot. Same as `SuperGlueImageProcessor.preprocess`. Expects either a list of 2
|
||||
images or a list of list of 2 images list with pixel values ranging from 0 to 255.
|
||||
keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
|
||||
A post processed keypoint matching output
|
||||
|
||||
Returns:
|
||||
`List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
|
||||
keypoints as well as the matching between them.
|
||||
"""
|
||||
images = validate_and_format_image_pairs(images)
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
|
||||
|
||||
results = []
|
||||
for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
|
||||
height0, width0 = image_pair[0].shape[:2]
|
||||
height1, width1 = image_pair[1].shape[:2]
|
||||
plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
|
||||
plot_image[:height0, :width0] = image_pair[0]
|
||||
plot_image[:height1, width0:] = image_pair[1]
|
||||
|
||||
plot_image_pil = Image.fromarray(plot_image)
|
||||
draw = ImageDraw.Draw(plot_image_pil)
|
||||
|
||||
keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
|
||||
keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
|
||||
for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
|
||||
keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
|
||||
):
|
||||
color = self._get_color(matching_score)
|
||||
draw.line(
|
||||
(keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
|
||||
fill=color,
|
||||
width=3,
|
||||
)
|
||||
draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
|
||||
draw.ellipse(
|
||||
(keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
|
||||
fill="black",
|
||||
)
|
||||
|
||||
results.append(plot_image_pil)
|
||||
return results
|
||||
|
||||
# Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
|
||||
def _get_color(self, score):
|
||||
"""Maps a score to a color."""
|
||||
r = int(255 * (1 - score))
|
||||
g = int(255 * score)
|
||||
b = 0
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
__all__ = ["SuperGlueImageProcessor"]
|
||||
|
||||
Reference in New Issue
Block a user