Add Fast Image Processor for PoolFormer (#37182)

* support poolformer fast image processor

* support test for crop_pct=None

* run make style

* Apply suggestions from code review

* rename test

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Vinh H. Pham
2025-04-24 02:55:33 +07:00
committed by GitHub
parent b491f128d6
commit dea1919be4
5 changed files with 305 additions and 15 deletions

View File

@@ -15,7 +15,7 @@
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from transformers.utils import is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -23,6 +23,9 @@ from ...test_image_processing_common import ImageProcessingTestMixin, prepare_im
if is_vision_available():
from transformers import PoolFormerImageProcessor
if is_torchvision_available():
from transformers import PoolFormerImageProcessorFast
class PoolFormerImageProcessingTester:
def __init__(
@@ -85,6 +88,7 @@ class PoolFormerImageProcessingTester:
@require_vision
class PoolFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = PoolFormerImageProcessor if is_vision_available() else None
fast_image_processing_class = PoolFormerImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
@@ -95,19 +99,29 @@ class PoolFormerImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize_and_center_crop"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "crop_pct"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize_and_center_crop"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "crop_pct"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 30})
self.assertEqual(image_processor.crop_size, {"height": 30, "width": 30})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 30})
self.assertEqual(image_processor.crop_size, {"height": 30, "width": 30})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
image_processor = self.image_processing_class.from_dict(self.image_processor_dict, size=42, crop_size=84)
self.assertEqual(image_processor.size, {"shortest_edge": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
@require_torch
@require_vision
class PoolFormerImageProcessingNoCropPctTest(PoolFormerImageProcessingTest):
def setUp(self):
super().setUp()
self.image_processor_tester = PoolFormerImageProcessingTester(self, crop_pct=None)