Fix Pan and Scan on batched images Gemma3 (#36864)

* process flattened images in fast image proc

* process flattened images in low proc and add tests

* remove print

* add unbalanced batch test pas image proc

* fix integration tests
This commit is contained in:
Yoni Gozlan
2025-03-21 13:56:00 -04:00
committed by GitHub
parent dd3933dd65
commit beb9b5b022
5 changed files with 185 additions and 130 deletions

View File

@@ -189,6 +189,13 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched unbalanced, 9 images because we have base image + 2 crops per each item
encoded_images = image_processing(
[[image_inputs[0], image_inputs[1]], [image_inputs[2]]], return_tensors="pt"
).pixel_values
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
@@ -250,3 +257,37 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
def test_call_numpy_4_channels(self):
pass
@require_vision
@require_torch
def test_slow_fast_equivalence_batched_pas(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
crop_config = {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
image_processor_dict = self.image_processor_dict
image_processor_dict.update(crop_config)
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)