Add Fast Image Processor for vilt (#37304)

* init vilt image processor fast

* Refactor image processor tests to use loop for all processors

* Add ViltImageProcessorFast with PyTorch-based optimized image processing

* Change made automatically by make fixup command

* Change made automatically by make fix-copies command

* Fix type hints in ViltImageProcessorFast for Python compatibility

* Define constants for image resizing based on COCO dataset aspect ratio

* Add missing property initializations to ViltImageProcessorFast

* Extract resize logic into dedicated method in ViltImageProcessorFast

* Extract padding logic into dedicated method

* Implement shape-based image grouping for optimized processing in Vilt

* Update test suite to verify ViltImageProcessorFast attributes

* Move variable declarations to _preprocess method parameters

* Remove unused parameters

* Rename _resize method to resize to override existing function

* Remove whitespace

* Remove unnecessary type check and conversion for stacked_images

* Remove redundant loop and apply padding directly to stacked images

* Refactor pad function to return images and mask as tuple instead of dict

* Add tests comparing padding masks in slow and fast implementations

* Update ViltImageProcessor tests to ensure compatibility between slow and fast implementations

* Replace add_start_docstrings with auto_docstring in ViltImageProcessorFast

* Move docstrings of custom args to ViltFastImageProcessorKwargs

* Use reorder_images function for both masks and images

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Jinyong Lee
2025-05-14 00:40:53 +09:00
committed by GitHub
parent 8771766a70
commit 342961f669
5 changed files with 309 additions and 13 deletions

View File

@@ -16,9 +16,10 @@
import unittest
import numpy as np
import torch
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -28,6 +29,9 @@ if is_vision_available():
from transformers import ViltImageProcessor
if is_torchvision_available():
from transformers import ViltImageProcessorFast
class ViltImageProcessingTester:
def __init__(
@@ -131,6 +135,7 @@ class ViltImageProcessingTester:
@require_vision
class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = ViltImageProcessor if is_vision_available() else None
fast_image_processing_class = ViltImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@@ -141,17 +146,43 @@ class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "size_divisor"))
self.assertTrue(hasattr(image_processing, "do_pad"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "model_input_names"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 30})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 30})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
def test_slow_fast_equivalence(self):
    """Check that the slow and fast ViLT image processors agree.

    Both processors are built from the same ``image_processor_dict`` with
    ``do_pad=True`` and applied to the same inputs (prepared with
    ``equal_resolution=False``). The test then compares:

    * ``pixel_values``: identical shape, values close within ``atol=1e-2``.
    * ``pixel_mask``: identical shape and exactly equal values.
    """
    # Each class attribute is set to None when its backend is not available,
    # so skip instead of failing with a TypeError on ``None(...)``.
    if self.image_processing_class is None or self.fast_image_processing_class is None:
        self.skipTest("Skipping slow/fast equivalence test as one of the image processors is not defined")

    image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
    image_processor_slow = self.image_processing_class(**self.image_processor_dict, do_pad=True)
    image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, do_pad=True)

    slow_outputs = image_processor_slow(image_inputs, return_tensors="pt")
    slow_pixel_values = slow_outputs.pixel_values
    slow_pixel_mask = slow_outputs.pixel_mask

    fast_outputs = image_processor_fast(image_inputs, return_tensors="pt")
    fast_pixel_values = fast_outputs.pixel_values
    fast_pixel_mask = fast_outputs.pixel_mask

    # Pixel values only need to match approximately between the two implementations.
    self.assertEqual(slow_pixel_values.shape, fast_pixel_values.shape)
    self.assertTrue(torch.allclose(slow_pixel_values, fast_pixel_values, atol=1e-2))

    # Padding masks must match exactly.
    self.assertEqual(slow_pixel_mask.shape, fast_pixel_mask.shape)
    self.assertTrue(torch.equal(slow_pixel_mask, fast_pixel_mask))