Add Fast Image Processor for Flava (#37135)

* support flava fast image processor

* run style and quality

* update test

* update according to reviews

* make style

* update comment on BICUBIC

* make style

---------

Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
This commit is contained in:
Vinh H. Pham
2025-04-14 20:05:31 +07:00
committed by GitHub
parent a5079a2c84
commit 49b9a69a36
5 changed files with 762 additions and 160 deletions

View File

@@ -16,9 +16,11 @@ import random
import unittest
import numpy as np
import requests
from PIL import Image
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_vision_available
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
@@ -30,6 +32,9 @@ if is_vision_available():
import PIL
from transformers import FlavaImageProcessor
if is_torchvision_available():
from transformers import FlavaImageProcessorFast
from transformers.image_utils import PILImageResampling
from transformers.models.flava.image_processing_flava import (
FLAVA_CODEBOOK_MEAN,
@@ -105,7 +110,8 @@ class FlavaImageProcessingTester:
self.codebook_do_resize = codebook_do_resize
self.codebook_size = codebook_size
self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.LANCZOS
# LANCZOS resample does not support torch Tensor. Use BICUBIC as closest alternative
self.codebook_resample = codebook_resample if codebook_resample is not None else PILImageResampling.BICUBIC
self.codebook_do_center_crop = codebook_do_center_crop
self.codebook_crop_size = codebook_crop_size
self.codebook_do_map_pixels = codebook_do_map_pixels
@@ -171,6 +177,7 @@ class FlavaImageProcessingTester:
@require_vision
class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = FlavaImageProcessor if is_vision_available() else None
fast_image_processing_class = FlavaImageProcessorFast if is_torchvision_available() else None
maxDiff = None
def setUp(self):
@@ -182,157 +189,161 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "crop_size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "masking_generator"))
self.assertTrue(hasattr(image_processing, "codebook_do_resize"))
self.assertTrue(hasattr(image_processing, "codebook_size"))
self.assertTrue(hasattr(image_processing, "codebook_resample"))
self.assertTrue(hasattr(image_processing, "codebook_do_center_crop"))
self.assertTrue(hasattr(image_processing, "codebook_crop_size"))
self.assertTrue(hasattr(image_processing, "codebook_do_map_pixels"))
self.assertTrue(hasattr(image_processing, "codebook_do_normalize"))
self.assertTrue(hasattr(image_processing, "codebook_image_mean"))
self.assertTrue(hasattr(image_processing, "codebook_image_std"))
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "crop_size"))
self.assertTrue(hasattr(image_processing, "do_center_crop"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "masking_generator"))
self.assertTrue(hasattr(image_processing, "codebook_do_resize"))
self.assertTrue(hasattr(image_processing, "codebook_size"))
self.assertTrue(hasattr(image_processing, "codebook_resample"))
self.assertTrue(hasattr(image_processing, "codebook_do_center_crop"))
self.assertTrue(hasattr(image_processing, "codebook_crop_size"))
self.assertTrue(hasattr(image_processing, "codebook_do_map_pixels"))
self.assertTrue(hasattr(image_processing, "codebook_do_normalize"))
self.assertTrue(hasattr(image_processing, "codebook_image_mean"))
self.assertTrue(hasattr(image_processing, "codebook_image_std"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 224, "width": 224})
self.assertEqual(image_processor.crop_size, {"height": 224, "width": 224})
self.assertEqual(image_processor.codebook_size, {"height": 112, "width": 112})
self.assertEqual(image_processor.codebook_crop_size, {"height": 112, "width": 112})
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 224, "width": 224})
self.assertEqual(image_processor.crop_size, {"height": 224, "width": 224})
self.assertEqual(image_processor.codebook_size, {"height": 112, "width": 112})
self.assertEqual(image_processor.codebook_crop_size, {"height": 112, "width": 112})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.codebook_size, {"height": 33, "width": 33})
self.assertEqual(image_processor.codebook_crop_size, {"height": 66, "width": 66})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size=42, crop_size=84, codebook_size=33, codebook_crop_size=66
)
self.assertEqual(image_processor.size, {"height": 42, "width": 42})
self.assertEqual(image_processor.crop_size, {"height": 84, "width": 84})
self.assertEqual(image_processor.codebook_size, {"height": 33, "width": 33})
self.assertEqual(image_processor.codebook_crop_size, {"height": 66, "width": 66})
def test_call_pil(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, PIL.Image.Image)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt")
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt")
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
self.assertEqual(
encoded_images.pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
# Test no bool masked pos
self.assertFalse("bool_masked_pos" in encoded_images)
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
def _test_call_framework(self, instance_class, prepare_kwargs):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, **prepare_kwargs)
for image in image_inputs:
self.assertIsInstance(image, instance_class)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, **prepare_kwargs)
for image in image_inputs:
self.assertIsInstance(image, instance_class)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt")
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt")
encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_mask_size()
self.assertEqual(
encoded_images.bool_masked_pos.shape,
(
self.image_processor_tester.batch_size,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_mask_size()
self.assertEqual(
encoded_images.bool_masked_pos.shape,
(
self.image_processor_tester.batch_size,
expected_height,
expected_width,
),
)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
# Test masking
encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt")
# Test masking
encoded_images = image_processing(image_inputs, return_image_mask=True, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_image_size()
self.assertEqual(
encoded_images.pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_mask_size()
self.assertEqual(
encoded_images.bool_masked_pos.shape,
(
self.image_processor_tester.batch_size,
expected_height,
expected_width,
),
)
expected_height, expected_width = self.image_processor_tester.get_expected_mask_size()
self.assertEqual(
encoded_images.bool_masked_pos.shape,
(
self.image_processor_tester.batch_size,
expected_height,
expected_width,
),
)
def test_call_numpy(self):
self._test_call_framework(np.ndarray, prepare_kwargs={"numpify": True})
@@ -346,40 +357,76 @@ class FlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self._test_call_framework(torch.Tensor, prepare_kwargs={"torchify": True})
def test_masking(self):
# Initialize image_processing
random.seed(1234)
image_processing = self.image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
random.seed(1234)
image_processing = image_processing_class(**self.image_processor_dict)
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_image_mask=True, return_tensors="pt")
self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_image_mask=True, return_tensors="pt")
self.assertEqual(encoded_images.bool_masked_pos.sum().item(), 75)
def test_codebook_pixels(self):
# Initialize image_processing
image_processing = self.image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, PIL.Image.Image)
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, PIL.Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size()
self.assertEqual(
encoded_images.codebook_pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_codebook_pixels=True, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size()
self.assertEqual(
encoded_images.codebook_pixel_values.shape,
(1, self.image_processor_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = image_processing(image_inputs, return_codebook_pixels=True, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size()
self.assertEqual(
encoded_images.codebook_pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
)
@require_vision
@require_torch
def test_slow_fast_equivalence(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
dummy_image = Image.open(
requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
image_processor_slow = self.image_processing_class(**self.image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
encoding_slow = image_processor_slow(
dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True
)
encoding_fast = image_processor_fast(
dummy_image, return_tensors="pt", return_codebook_pixels=True, return_image_mask=True
)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)
# Test batched
encoded_images = image_processing(image_inputs, return_codebook_pixels=True, return_tensors="pt")
expected_height, expected_width = self.image_processor_tester.get_expected_codebook_image_size()
self.assertEqual(
encoded_images.codebook_pixel_values.shape,
(
self.image_processor_tester.batch_size,
self.image_processor_tester.num_channels,
expected_height,
expected_width,
),
self.assertTrue(
torch.allclose(encoding_slow.codebook_pixel_values, encoding_fast.codebook_pixel_values, atol=1e-1)
)
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.codebook_pixel_values - encoding_fast.codebook_pixel_values)).item(),
1e-3,
)