From c353f2bb5e29a52a15831d3fbe565f7b7fff3a08 Mon Sep 17 00:00:00 2001 From: Avigyan Sinha <72064090+arkhamHack@users.noreply.github.com> Date: Mon, 28 Jul 2025 23:45:06 +0530 Subject: [PATCH] Superpoint fast image processor (#37804) * feat: superpoint fast image processor * fix: reran fast cli command to generate fast config * feat: updated test cases * fix: removed old model add * fix: format fix * Update src/transformers/models/superpoint/image_processing_superpoint_fast.py Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> * fix: ported to torch and made requested changes * fix: removed changes to init * fix: init fix * fix: init format fix * fixed testcases and ported to torch * fix: format fixes * failed test case fix * fix superpoint fast * fix docstring --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Co-authored-by: yonigozlan --- docs/source/en/model_doc/superpoint.md | 5 + .../models/auto/image_processing_auto.py | 7 + .../models/superpoint/__init__.py | 1 + .../superpoint/image_processing_superpoint.py | 12 +- .../image_processing_superpoint_fast.py | 182 ++++++++++++++++++ .../test_image_processing_superpoint.py | 79 ++++---- 6 files changed, 252 insertions(+), 34 deletions(-) create mode 100644 src/transformers/models/superpoint/image_processing_superpoint_fast.py diff --git a/docs/source/en/model_doc/superpoint.md b/docs/source/en/model_doc/superpoint.md index 31f40e5a37..27ab95ac67 100644 --- a/docs/source/en/model_doc/superpoint.md +++ b/docs/source/en/model_doc/superpoint.md @@ -130,6 +130,11 @@ processed_outputs = processor.post_process_keypoint_detection(outputs, [image_si [[autodoc]] SuperPointImageProcessor +- preprocess + +## SuperPointImageProcessorFast + +[[autodoc]] SuperPointImageProcessorFast - preprocess - post_process_keypoint_detection diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 0a0cc6a38c..527b25c10a 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -162,6 +162,13 @@ else: ("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")), ("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")), ("superglue", ("SuperGlueImageProcessor",)), + ( + "superpoint", + ( + "SuperPointImageProcessor", + "SuperPointImageProcessorFast", + ), + ), ("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin", ("ViTImageProcessor", "ViTImageProcessorFast")), ("swin2sr", ("Swin2SRImageProcessor", "Swin2SRImageProcessorFast")), diff --git a/src/transformers/models/superpoint/__init__.py b/src/transformers/models/superpoint/__init__.py index aab40abaa8..ccec260fa5 100644 --- a/src/transformers/models/superpoint/__init__.py +++ b/src/transformers/models/superpoint/__init__.py @@ -20,6 +20,7 @@ from ...utils.import_utils import define_import_structure if TYPE_CHECKING: from .configuration_superpoint import * from .image_processing_superpoint import * + from .image_processing_superpoint_fast import * from .modeling_superpoint import * else: import sys diff --git a/src/transformers/models/superpoint/image_processing_superpoint.py b/src/transformers/models/superpoint/image_processing_superpoint.py index 9759e751b2..e6167160c2 100644 --- a/src/transformers/models/superpoint/image_processing_superpoint.py +++ b/src/transformers/models/superpoint/image_processing_superpoint.py @@ -23,6 +23,7 @@ from ...image_transforms import resize, to_channel_dimension_format from ...image_utils import ( ChannelDimension, ImageInput, + PILImageResampling, infer_channel_dimension_format, is_scaled_image, make_list_of_images, @@ -107,6 +108,8 @@ class SuperPointImageProcessor(BaseImageProcessor): size (`dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`): Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to `True`. Can be overridden by `size` in the `preprocess` method. + resample (`Resampling`, *optional*, defaults to `2`): + Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method. do_rescale (`bool`, *optional*, defaults to `True`): Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in the `preprocess` method. @@ -123,6 +126,7 @@ class SuperPointImageProcessor(BaseImageProcessor): self, do_resize: bool = True, size: Optional[dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, do_rescale: bool = True, rescale_factor: float = 1 / 255, do_grayscale: bool = False, @@ -134,6 +138,7 @@ class SuperPointImageProcessor(BaseImageProcessor): self.do_resize = do_resize self.size = size + self.resample = resample self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_grayscale = do_grayscale @@ -182,6 +187,7 @@ class SuperPointImageProcessor(BaseImageProcessor): images, do_resize: Optional[bool] = None, size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, do_rescale: Optional[bool] = None, rescale_factor: Optional[float] = None, do_grayscale: Optional[bool] = None, @@ -231,6 +237,7 @@ class SuperPointImageProcessor(BaseImageProcessor): """ do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor do_grayscale = do_grayscale if do_grayscale is not None else self.do_grayscale @@ -266,7 +273,10 @@ class SuperPointImageProcessor(BaseImageProcessor): input_data_format = infer_channel_dimension_format(images[0]) if do_resize: - images = [self.resize(image=image, size=size, input_data_format=input_data_format) for image in images] + images = [ + self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) + for image in images + ] if do_rescale: images = [ diff --git a/src/transformers/models/superpoint/image_processing_superpoint_fast.py b/src/transformers/models/superpoint/image_processing_superpoint_fast.py new file mode 100644 index 0000000000..e70bb397ff --- /dev/null +++ b/src/transformers/models/superpoint/image_processing_superpoint_fast.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for Superpoint.""" + +from typing import TYPE_CHECKING, Optional, Union + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + PILImageResampling, + SizeDict, +) +from ...processing_utils import Unpack +from ...utils import ( + TensorType, + auto_docstring, + is_torch_available, + is_torchvision_available, + is_torchvision_v2_available, +) + + +if is_torch_available(): + import torch + +if TYPE_CHECKING: + from .modeling_superpoint import SuperPointKeypointDescriptionOutput + +if is_torchvision_v2_available(): + import torchvision.transforms.v2.functional as F +elif is_torchvision_available(): + import torchvision.transforms.functional as F + + +def is_grayscale( + image: "torch.Tensor", +): + """Checks if an image is grayscale (all RGB channels are identical).""" + if image.ndim < 3 or image.shape[0 if image.ndim == 3 else 1] == 1: + return True + return torch.all(image[..., 0, :, :] == image[..., 1, :, :]) and torch.all( + image[..., 1, :, :] == image[..., 2, :, :] + ) + + +class SuperPointFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): + r""" + do_grayscale (`bool`, *optional*, defaults to `True`): + Whether to convert the image to grayscale. Can be overridden by `do_grayscale` in the `preprocess` method. + """ + + do_grayscale: Optional[bool] = True + + +def convert_to_grayscale( + image: "torch.Tensor", +) -> "torch.Tensor": + """ + Converts an image to grayscale format using the NTSC formula. Only support torch.Tensor. + + This function is supposed to return a 1-channel image, but it returns a 3-channel image with the same value in each + channel, because of an issue that is discussed in : + https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446 + + Args: + image (torch.Tensor): + The image to convert. + """ + if is_grayscale(image): + return image + return F.rgb_to_grayscale(image, num_output_channels=3) + + +@auto_docstring +class SuperPointImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + size = {"height": 480, "width": 640} + default_to_square = False + do_resize = True + do_rescale = True + rescale_factor = 1 / 255 + do_normalize = None + valid_kwargs = SuperPointFastImageProcessorKwargs + + def __init__(self, **kwargs: Unpack[SuperPointFastImageProcessorKwargs]): + super().__init__(**kwargs) + + def _preprocess( + self, + images: list["torch.Tensor"], + size: Union[dict[str, int], SizeDict], + rescale_factor: float, + do_rescale: bool, + do_resize: bool, + interpolation: Optional["F.InterpolationMode"], + do_grayscale: bool, + disable_grouping: bool, + return_tensors: Union[str, TensorType], + **kwargs, + ) -> BatchFeature: + grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping) + processed_images_grouped = {} + for shape, stacked_images in grouped_images.items(): + if do_grayscale: + stacked_images = convert_to_grayscale(stacked_images) + if do_resize: + stacked_images = self.resize(stacked_images, size=size, interpolation=interpolation) + if do_rescale: + stacked_images = self.rescale(stacked_images, rescale_factor) + processed_images_grouped[shape] = stacked_images + processed_images = reorder_images(processed_images_grouped, grouped_images_index) + processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images + return BatchFeature(data={"pixel_values": processed_images}) + + def post_process_keypoint_detection( + self, outputs: "SuperPointKeypointDescriptionOutput", target_sizes: Union[TensorType, list[tuple]] + ) -> list[dict[str, "torch.Tensor"]]: + """ + Converts the raw output of [`SuperPointForKeypointDetection`] into lists of keypoints, scores and descriptors + with coordinates absolute to the original image sizes. + + Args: + outputs ([`SuperPointKeypointDescriptionOutput`]): + Raw outputs of the model containing keypoints in a relative (x, y) format, with scores and descriptors. + target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + `(height, width)` of each image in the batch. This must be the original + image size (before any processing). + Returns: + `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in absolute format according + to target_sizes, scores and descriptors for an image in the batch as predicted by the model. + """ + if len(outputs.mask) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask") + + if isinstance(target_sizes, list): + image_sizes = torch.tensor(target_sizes, device=outputs.mask.device) + else: + if target_sizes.shape[1] != 2: + raise ValueError( + "Each element of target_sizes must contain the size (h, w) of each image of the batch" + ) + image_sizes = target_sizes + + # Flip the image sizes to (width, height) and convert keypoints to absolute coordinates + image_sizes = torch.flip(image_sizes, [1]) + masked_keypoints = outputs.keypoints * image_sizes[:, None] + + # Convert masked_keypoints to int + masked_keypoints = masked_keypoints.to(torch.int32) + + results = [] + for image_mask, keypoints, scores, descriptors in zip( + outputs.mask, masked_keypoints, outputs.scores, outputs.descriptors + ): + indices = torch.nonzero(image_mask).squeeze(1) + keypoints = keypoints[indices] + scores = scores[indices] + descriptors = descriptors[indices] + results.append({"keypoints": keypoints, "scores": scores, "descriptors": descriptors}) + + return results + + +__all__ = ["SuperPointImageProcessorFast"] diff --git a/tests/models/superpoint/test_image_processing_superpoint.py b/tests/models/superpoint/test_image_processing_superpoint.py index 38a90ca63a..242e3702be 100644 --- a/tests/models/superpoint/test_image_processing_superpoint.py +++ b/tests/models/superpoint/test_image_processing_superpoint.py @@ -16,12 +16,9 @@ import unittest import numpy as np from transformers.testing_utils import require_torch, require_vision -from transformers.utils import is_torch_available, is_vision_available +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available -from ...test_image_processing_common import ( - ImageProcessingTestMixin, - prepare_image_inputs, -) +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs if is_torch_available(): @@ -32,6 +29,9 @@ if is_torch_available(): if is_vision_available(): from transformers import SuperPointImageProcessor + if is_torchvision_available(): + from transformers import SuperPointImageProcessorFast + class SuperPointImageProcessingTester: def __init__( @@ -100,6 +100,7 @@ class SuperPointImageProcessingTester: @require_vision class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image_processing_class = SuperPointImageProcessor if is_vision_available() else None + fast_image_processing_class = SuperPointImageProcessorFast if is_torchvision_available() else None def setUp(self) -> None: super().setUp() @@ -110,40 +111,44 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) return self.image_processor_tester.prepare_image_processor_dict() def test_image_processing(self): - image_processing = self.image_processing_class(**self.image_processor_dict) - self.assertTrue(hasattr(image_processing, "do_resize")) - self.assertTrue(hasattr(image_processing, "size")) - self.assertTrue(hasattr(image_processing, "do_rescale")) - self.assertTrue(hasattr(image_processing, "rescale_factor")) - self.assertTrue(hasattr(image_processing, "do_grayscale")) + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_rescale")) + self.assertTrue(hasattr(image_processing, "rescale_factor")) + self.assertTrue(hasattr(image_processing, "do_grayscale")) def test_image_processor_from_dict_with_kwargs(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - self.assertEqual(image_processor.size, {"height": 480, "width": 640}) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 480, "width": 640}) - image_processor = self.image_processing_class.from_dict( - self.image_processor_dict, size={"height": 42, "width": 42} - ) - self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + image_processor = self.image_processing_class.from_dict( + self.image_processor_dict, size={"height": 42, "width": 42} + ) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) @unittest.skip(reason="SuperPointImageProcessor is always supposed to return a grayscaled image") def test_call_numpy_4_channels(self): pass def test_input_image_properly_converted_to_grayscale(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs() - pre_processed_images = image_processor.preprocess(image_inputs) - for image in pre_processed_images["pixel_values"]: - self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs() + pre_processed_images = image_processor.preprocess(image_inputs) + for image in pre_processed_images["pixel_values"]: + if isinstance(image, torch.Tensor): + self.assertTrue( + torch.all(image[0, ...] == image[1, ...]).item() + and torch.all(image[1, ...] == image[2, ...]).item() + ) + else: + self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...])) @require_torch def test_post_processing_keypoint_detection(self): - image_processor = self.image_processing_class.from_dict(self.image_processor_dict) - image_inputs = self.image_processor_tester.prepare_image_inputs() - pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt") - outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images) - def check_post_processed_output(post_processed_output, image_size): for post_processed_output, image_size in zip(post_processed_output, image_size): self.assertTrue("keypoints" in post_processed_output) @@ -157,12 +162,20 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase) self.assertTrue(all_below_image_size) self.assertTrue(all_above_zero) - tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs] - tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes) + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + image_inputs = self.image_processor_tester.prepare_image_inputs() + pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt") + outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images) - check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes) + tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs] + tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes) - tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1) - tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes) + check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes) - check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes) + tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1) + tensor_post_processed_outputs = image_processor.post_process_keypoint_detection( + outputs, tensor_image_sizes + ) + + check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)