Add Fast Image Processor for vilt (#37304)
* init vilt image processor fast * Refactor image processor tests to use loop for all processors * Add ViltImageProcessorFast with PyTorch-based optimized image processing * Change made automatically by make fixup command * Change made automatically by make fix-copies command * Fix type hints in ViltImageProcessorFast for Python compatibility * Define constants for image resizing based on COCO dataset aspect ratio * Add missing property initializations to ViltImageProcessorFast * Extract resize logic into dedicated method in ViltImageProcessorFast * Extract padding logic into dedicated method * Implement shape-based image grouping for optimized processing in Vilt * Update test suite to verify ViltImageProcessorFast attributes * Move variable declarations to _preprocess method parameters * Remove unused parameters * Rename _resize method to resize to override existing function * Remove whitespace * Remove unnecessary type check and conversion for stacked_images * Remove redundant loop and apply padding directly to stacked images * Refactor pad function to return images and mask as tuple instead of dict * Add tests comparing padding masks in slow and fast implementations * Update ViltImageProcessor tests to ensure compatibility between slow and fast implementations * Replace add_start_docstrings with auto_docstring in ViltImageProcessorFast * Move docstrings of custom args to ViltFastImageProcessorKwargs * Use reorder_images function for both masks and images --------- Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
@@ -16,9 +16,10 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision
|
||||
from transformers.utils import is_vision_available
|
||||
from transformers.utils import is_torchvision_available, is_vision_available
|
||||
|
||||
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
|
||||
|
||||
@@ -28,6 +29,9 @@ if is_vision_available():
|
||||
|
||||
from transformers import ViltImageProcessor
|
||||
|
||||
if is_torchvision_available():
|
||||
from transformers import ViltImageProcessorFast
|
||||
|
||||
|
||||
class ViltImageProcessingTester:
|
||||
def __init__(
|
||||
@@ -131,6 +135,7 @@ class ViltImageProcessingTester:
|
||||
@require_vision
|
||||
class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing_class = ViltImageProcessor if is_vision_available() else None
|
||||
fast_image_processing_class = ViltImageProcessorFast if is_torchvision_available() else None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@@ -141,17 +146,43 @@ class ViltImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
return self.image_processor_tester.prepare_image_processor_dict()
|
||||
|
||||
def test_image_processor_properties(self):
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "size_divisor"))
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
self.assertTrue(hasattr(image_processing, "image_mean"))
|
||||
self.assertTrue(hasattr(image_processing, "image_std"))
|
||||
self.assertTrue(hasattr(image_processing, "do_normalize"))
|
||||
self.assertTrue(hasattr(image_processing, "do_resize"))
|
||||
self.assertTrue(hasattr(image_processing, "size"))
|
||||
self.assertTrue(hasattr(image_processing, "size_divisor"))
|
||||
self.assertTrue(hasattr(image_processing, "do_pad"))
|
||||
self.assertTrue(hasattr(image_processing, "resample"))
|
||||
self.assertTrue(hasattr(image_processing, "do_rescale"))
|
||||
self.assertTrue(hasattr(image_processing, "model_input_names"))
|
||||
|
||||
def test_image_processor_from_dict_with_kwargs(self):
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 30})
|
||||
for image_processing_class in self.image_processor_list:
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 30})
|
||||
|
||||
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42)
|
||||
self.assertEqual(image_processor.size, {"shortest_edge": 42})
|
||||
|
||||
def test_slow_fast_equivalence(self):
|
||||
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
|
||||
|
||||
image_processor_slow = self.image_processing_class(**self.image_processor_dict, do_pad=True)
|
||||
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict, do_pad=True)
|
||||
|
||||
slow_outputs = image_processor_slow(image_inputs, return_tensors="pt")
|
||||
slow_pixel_values = slow_outputs.pixel_values
|
||||
slow_pixel_mask = slow_outputs.pixel_mask
|
||||
|
||||
fast_outputs = image_processor_fast(image_inputs, return_tensors="pt")
|
||||
fast_pixel_values = fast_outputs.pixel_values
|
||||
fast_pixel_mask = fast_outputs.pixel_mask
|
||||
|
||||
self.assertEqual(slow_pixel_values.shape, fast_pixel_values.shape)
|
||||
self.assertTrue(torch.allclose(slow_pixel_values, fast_pixel_values, atol=1e-2))
|
||||
|
||||
self.assertEqual(slow_pixel_mask.shape, fast_pixel_mask.shape)
|
||||
self.assertTrue(torch.equal(slow_pixel_mask, fast_pixel_mask))
|
||||
|
||||
Reference in New Issue
Block a user