From d8663cb8c50322f351b73132cf8f02f14cf5aeea Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Tue, 11 Mar 2025 03:22:48 -0600 Subject: [PATCH] Fix bugs in mllama image processing (#36156) * fix: handle input_channel_dim == channels_last Signed-off-by: Travis Johnson * fix: default PIL images to channels_last Signed-off-by: Travis Johnson * Apply suggestions from code review Co-authored-by: Pavel Iakubovskii * fixup from review batch Signed-off-by: Travis Johnson * test: add 1x1 PIL image to ambiguous channel test Signed-off-by: Travis Johnson * fix(mllama): avoid 0 dimension for image with impractical aspect ratio Signed-off-by: Travis Johnson --------- Signed-off-by: Travis Johnson Co-authored-by: Pavel Iakubovskii --- .../models/mllama/image_processing_mllama.py | 21 ++++++++----- .../mllama/test_image_processing_mllama.py | 30 +++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mllama/image_processing_mllama.py b/src/transformers/models/mllama/image_processing_mllama.py index 9ff077f150..bcb97bbcd9 100644 --- a/src/transformers/models/mllama/image_processing_mllama.py +++ b/src/transformers/models/mllama/image_processing_mllama.py @@ -93,7 +93,7 @@ def get_image_size_fit_to_canvas( canvas_height and canvas_width, while ensuring that the image dimensions are not smaller than tile_size. If the image is larger than the canvas, the returned size will fit within the canvas. If the image already fits within the canvas, the size remains unchanged. - The aspect ratio of the original image is preserved. + The aspect ratio of the original image is preserved as much as possible. Args: image_height (`int`): @@ -120,10 +120,12 @@ def get_image_size_fit_to_canvas( if scale_w < scale_h: new_width = target_width - new_height = min(math.floor(image_height * scale_w), target_height) + # minimum height is 1 to avoid invalid height of 0 + new_height = min(math.floor(image_height * scale_w) or 1, target_height) else: new_height = target_height - new_width = min(math.floor(image_width * scale_h), target_width) + # minimum width is 1 to avoid invalid width of 0 + new_width = min(math.floor(image_width * scale_h) or 1, target_width) return new_height, new_width @@ -695,8 +697,6 @@ class MllamaImageProcessor(BaseImageProcessor): if self.do_convert_rgb: images_list = [[convert_to_rgb(image) for image in images] for images in images_list] - images_list = [[to_numpy_array(image) for image in images] for images in images_list] - batch_images = [] batch_aspect_ratios = [] @@ -707,6 +707,13 @@ class MllamaImageProcessor(BaseImageProcessor): # iterate over images in a batch sample for image in images: + # default PIL images to channels_last + if input_data_format is None and isinstance(image, PIL.Image.Image): + input_data_format = ChannelDimension.LAST + + # convert to numpy array for processing + image = to_numpy_array(image) + # convert images to channels first format for faster processing # LAST is slower for `pad` and not supported by `split_to_tiles` data_format = ChannelDimension.FIRST @@ -735,7 +742,7 @@ class MllamaImageProcessor(BaseImageProcessor): image = self.rescale( image=image, scale=rescale_factor, - input_data_format=input_data_format, + input_data_format=data_format, data_format=data_format, ) @@ -744,7 +751,7 @@ class MllamaImageProcessor(BaseImageProcessor): image=image, mean=image_mean, std=image_std, - input_data_format=input_data_format, + input_data_format=data_format, data_format=data_format, ) diff --git a/tests/models/mllama/test_image_processing_mllama.py b/tests/models/mllama/test_image_processing_mllama.py index 351f1f16f2..4b7fbcb81d 100644 --- a/tests/models/mllama/test_image_processing_mllama.py +++ b/tests/models/mllama/test_image_processing_mllama.py @@ -224,6 +224,36 @@ class MllamaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) ) + def test_call_channels_last(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + + # a white 1x1 pixel RGB image + image_inputs = [[np.full(shape=(1, 1, 3), fill_value=1.0, dtype=float)]] + encoded_images = image_processing( + image_inputs, return_tensors="pt", input_data_format="channels_last" + ).pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + def test_ambiguous_channel_pil_image(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + + image_inputs = [[Image.new("RGB", (1, 1))], [Image.new("RGB", (100, 1))]] + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual(tuple(encoded_images.shape), (2, *expected_output_image_shape)) + + def test_resize_impractical_aspect_ratio(self): + # Initialize image_processing + image_processing = self.image_processing_class(**self.image_processor_dict) + # Ensure that no error is raised even if the aspect ratio is impractical + image_inputs = [[Image.new("RGB", (9999999, 1))]] + encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + def test_call_pytorch(self): # Initialize image_processing image_processing = self.image_processing_class(**self.image_processor_dict)