Fix Pan and Scan on batched images Gemma3 (#36864)

* process flattened images in fast image proc

* process flattened images in low proc and add tests

* remove print

* add unbalanced batch test pas image proc

* fix integration tests
This commit is contained in:
Yoni Gozlan
2025-03-21 13:56:00 -04:00
committed by GitHub
parent dd3933dd65
commit beb9b5b022
5 changed files with 185 additions and 130 deletions

View File

@@ -189,6 +189,13 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched unbalanced, 9 images because we have base image + 2 crops per each item
encoded_images = image_processing(
[[image_inputs[0], image_inputs[1]], [image_inputs[2]]], return_tensors="pt"
).pixel_values
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
@@ -250,3 +257,37 @@ class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
def test_call_numpy_4_channels(self):
pass
@require_vision
@require_torch
def test_slow_fast_equivalence_batched_pas(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
crop_config = {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
image_processor_dict = self.image_processor_dict
image_processor_dict.update(crop_config)
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
self.assertTrue(torch.allclose(encoding_slow.pixel_values, encoding_fast.pixel_values, atol=1e-1))
self.assertLessEqual(
torch.mean(torch.abs(encoding_slow.pixel_values - encoding_fast.pixel_values)).item(), 1e-3
)

View File

@@ -395,7 +395,7 @@ class Gemma3IntegrationTest(unittest.TestCase):
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'] # fmt: skip
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_batch(self):
@@ -467,7 +467,56 @@ class Gemma3IntegrationTest(unittest.TestCase):
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background.'] # fmt: skip
EXPECTED_TEXTS = ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background. It looks like the cow is enjoying the beach'] # fmt: skip
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
self.assertEqual(output_text, EXPECTED_TEXTS)
def test_model_4b_batch_crops(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).to(torch_device)
crop_config = {
"images_kwargs": {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
}
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
**crop_config,
).to(torch_device)
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9
EXPECTED_TEXTS = [
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a beach with a turquoise ocean and blue sky in the background. It looks like the cow is enjoying the beach",
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nWhile they all feature a brown cow in the foreground and a similar background (including the stop signs and",
] # fmt: skip
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
self.assertEqual(output_text, EXPECTED_TEXTS)