Superpoint fast image processor (#37804)

* feat: superpoint fast image processor

* fix: reran fast cli command to generate fast config

* feat: updated test cases

* fix: removed old model add

* fix: format fix

* Update src/transformers/models/superpoint/image_processing_superpoint_fast.py

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>

* fix: ported to torch and made requested changes

* fix: removed changes to init

* fix: init fix

* fix: init format fix

* fixed testcases and ported to torch

* fix: format fixes

* failed
test case fix

* fix superpoint fast

* fix docstring

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
This commit is contained in:
Avigyan Sinha
2025-07-28 23:45:06 +05:30
committed by GitHub
parent 14adcbd937
commit c353f2bb5e
6 changed files with 252 additions and 34 deletions

View File

@@ -16,12 +16,9 @@ import unittest
import numpy as np
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import (
ImageProcessingTestMixin,
prepare_image_inputs,
)
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
@@ -32,6 +29,9 @@ if is_torch_available():
if is_vision_available():
from transformers import SuperPointImageProcessor
if is_torchvision_available():
from transformers import SuperPointImageProcessorFast
class SuperPointImageProcessingTester:
def __init__(
@@ -100,6 +100,7 @@ class SuperPointImageProcessingTester:
@require_vision
class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SuperPointImageProcessor if is_vision_available() else None
fast_image_processing_class = SuperPointImageProcessorFast if is_torchvision_available() else None
def setUp(self) -> None:
super().setUp()
@@ -110,40 +111,44 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processing(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_grayscale"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_grayscale"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 480, "width": 640})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 480, "width": 640})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 42, "width": 42}
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 42, "width": 42}
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
@unittest.skip(reason="SuperPointImageProcessor is always supposed to return a grayscaled image")
def test_call_numpy_4_channels(self):
pass
def test_input_image_properly_converted_to_grayscale(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs()
pre_processed_images = image_processor.preprocess(image_inputs)
for image in pre_processed_images["pixel_values"]:
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs()
pre_processed_images = image_processor.preprocess(image_inputs)
for image in pre_processed_images["pixel_values"]:
if isinstance(image, torch.Tensor):
self.assertTrue(
torch.all(image[0, ...] == image[1, ...]).item()
and torch.all(image[1, ...] == image[2, ...]).item()
)
else:
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
@require_torch
def test_post_processing_keypoint_detection(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs()
pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
def check_post_processed_output(post_processed_output, image_size):
for post_processed_output, image_size in zip(post_processed_output, image_size):
self.assertTrue("keypoints" in post_processed_output)
@@ -157,12 +162,20 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
self.assertTrue(all_below_image_size)
self.assertTrue(all_above_zero)
tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs]
tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs()
pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs]
tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)
tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)
check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(
outputs, tensor_image_sizes
)
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)