Superpoint fast image processor (#37804)
* feat: superpoint fast image processor
* fix: reran fast cli command to generate fast config
* feat: updated test cases
* fix: removed old model add
* fix: format fix
* Update src/transformers/models/superpoint/image_processing_superpoint_fast.py
  Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
* fix: ported to torch and made requested changes
* fix: removed changes to init
* fix: init fix
* fix: init format fix
* fixed testcases and ported to torch
* fix: format fixes
* failed test case fix
* fix superpoint fast
* fix docstring

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
This commit is contained in:
@@ -16,12 +16,9 @@ import unittest
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import (
|
||||
ImageProcessingTestMixin,
|
||||
prepare_image_inputs,
|
||||
)
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -32,6 +29,9 @@ if is_torch_available():
|
||||
if is_vision_available():
|
||||
from transformers import SuperPointImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import SuperPointImageProcessorFast
|
||||
|
||||
|
||||
class SuperPointImageProcessingTester:
|
||||
def __init__(
|
||||
@@ -100,6 +100,7 @@ class SuperPointImageProcessingTester:
|
||||
@require_vision
|
||||
class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = SuperPointImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = SuperPointImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
@@ -110,40 +111,44 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processing(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_grayscale"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "rescale_factor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_grayscale"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 480, "width": 640})
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"height": 480, "width": 640})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(
|
||||
self.image_processor_dict, size={"height": 42, "width": 42}
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
image_processor = self.image_processing_class.from_dict(
|
||||
self.image_processor_dict, size={"height": 42, "width": 42}
|
||||
)
|
||||
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
|
||||
|
||||
@unittest.skip(reason="SuperPointImageProcessor is always supposed to return a grayscaled image")
|
||||
def test_call_numpy_4_channels(self):
|
||||
pass
|
||||
|
||||
def test_input_image_properly_converted_to_grayscale(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs()
|
||||
pre_processed_images = image_processor.preprocess(image_inputs)
|
||||
for image in pre_processed_images["pixel_values"]:
|
||||
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs()
|
||||
pre_processed_images = image_processor.preprocess(image_inputs)
|
||||
for image in pre_processed_images["pixel_values"]:
|
||||
if isinstance(image, torch.Tensor):
|
||||
self.assertTrue(
|
||||
torch.all(image[0, ...] == image[1, ...]).item()
|
||||
and torch.all(image[1, ...] == image[2, ...]).item()
|
||||
)
|
||||
else:
|
||||
self.assertTrue(np.all(image[0, ...] == image[1, ...]) and np.all(image[1, ...] == image[2, ...]))
|
||||
|
||||
@require_torch
|
||||
def test_post_processing_keypoint_detection(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs()
|
||||
pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
|
||||
outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
|
||||
|
||||
def check_post_processed_output(post_processed_output, image_size):
|
||||
for post_processed_output, image_size in zip(post_processed_output, image_size):
|
||||
self.assertTrue("keypoints" in post_processed_output)
|
||||
@@ -157,12 +162,20 @@ class SuperPointImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
||||
self.assertTrue(all_below_image_size)
|
||||
self.assertTrue(all_above_zero)
|
||||
|
||||
tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs]
|
||||
tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs()
|
||||
pre_processed_images = image_processor.preprocess(image_inputs, return_tensors="pt")
|
||||
outputs = self.image_processor_tester.prepare_keypoint_detection_output(**pre_processed_images)
|
||||
|
||||
check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
|
||||
tuple_image_sizes = [(image.size[0], image.size[1]) for image in image_inputs]
|
||||
tuple_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tuple_image_sizes)
|
||||
|
||||
tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
|
||||
tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(outputs, tensor_image_sizes)
|
||||
check_post_processed_output(tuple_post_processed_outputs, tuple_image_sizes)
|
||||
|
||||
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
|
||||
tensor_image_sizes = torch.tensor([image.size for image in image_inputs]).flip(1)
|
||||
tensor_post_processed_outputs = image_processor.post_process_keypoint_detection(
|
||||
outputs, tensor_image_sizes
|
||||
)
|
||||
|
||||
check_post_processed_output(tensor_post_processed_outputs, tensor_image_sizes)
|
||||
|
||||
Reference in New Issue
Block a user