From e3b70b0d1c15c87ba2010b00830fbd92b2c50252 Mon Sep 17 00:00:00 2001 From: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com> Date: Mon, 12 May 2025 15:13:40 -0400 Subject: [PATCH] Refactor image processor phi4 (#36976) * refactor image processor phi4 * nits fast image proc * add image tests phi4 * Fix image processing tests * update integration tests * remove revision and add comment in integration tests --- .../models/auto/image_processing_auto.py | 2 +- .../image_processing_phi4_multimodal_fast.py | 236 +++++++------- .../test_image_processing_phi4_multimodal.py | 307 ++++++++++++++++++ .../test_modeling_phi4_multimodal.py | 17 +- tests/test_image_processing_common.py | 2 +- 5 files changed, 436 insertions(+), 128 deletions(-) create mode 100644 tests/models/phi4_multimodal/test_image_processing_phi4_multimodal.py diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 41d446cf4e..7141cd2e9e 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -128,7 +128,7 @@ else: ("owlvit", ("OwlViTImageProcessor", "OwlViTImageProcessorFast")), ("paligemma", ("SiglipImageProcessor", "SiglipImageProcessorFast")), ("perceiver", ("PerceiverImageProcessor", "PerceiverImageProcessorFast")), - ("phi4_multimodal", "Phi4MultimodalImageProcessorFast"), + ("phi4_multimodal", ("Phi4MultimodalImageProcessorFast",)), ("pix2struct", ("Pix2StructImageProcessor",)), ("pixtral", ("PixtralImageProcessor", "PixtralImageProcessorFast")), ("poolformer", ("PoolFormerImageProcessor", "PoolFormerImageProcessorFast")), diff --git a/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py b/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py index 813f89dbea..2273293b4c 100644 --- a/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +++ b/src/transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py @@ -12,53 +12,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Processor class for Phi4Multimodal -""" - import math from typing import List, Optional, Union import torch -from torchvision.transforms import functional as F from ...image_processing_utils_fast import ( BaseImageProcessorFast, BatchFeature, DefaultFastImageProcessorKwargs, Unpack, - convert_to_rgb, ) -from ...image_utils import ImageInput, make_flat_list_of_images, valid_images -from ...utils import TensorType, logging +from ...image_utils import ImageInput, SizeDict +from ...utils import ( + TensorType, + auto_docstring, + is_torchvision_available, + is_torchvision_v2_available, + is_vision_available, + logging, +) +if is_vision_available(): + from ...image_utils import PILImageResampling +if is_torchvision_available(): + if is_torchvision_v2_available(): + from torchvision.transforms.v2 import functional as F + else: + from torchvision.transforms import functional as F + logger = logging.get_logger(__name__) class Phi4MultimodalFastImageProcessorKwargs(DefaultFastImageProcessorKwargs): - image_size: Optional[int] + r""" + patch_size (`int`, *optional*): + The size of the patch. + dynamic_hd (`int`, *optional*): + The maximum number of crops per image. + """ + patch_size: Optional[int] dynamic_hd: Optional[int] +@auto_docstring class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast): - r""" - Constructs a Phi4Multimodal image processor. - """ - - image_size = 448 + resample = PILImageResampling.BICUBIC + size = {"height": 448, "width": 448} patch_size = 14 dynamic_hd = 36 image_mean = [0.5, 0.5, 0.5] image_std = [0.5, 0.5, 0.5] - valid_init_kwargs = Phi4MultimodalFastImageProcessorKwargs + do_resize = True + do_rescale = True + do_normalize = True + do_convert_rgb = True + valid_kwargs = Phi4MultimodalFastImageProcessorKwargs model_input_names = ["image_pixel_values", "image_sizes", "image_attention_mask"] def __init__(self, **kwargs: Unpack[Phi4MultimodalFastImageProcessorKwargs]): super().__init__(**kwargs) - def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height): + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): best_ratio_diff = float("inf") best_ratio = (1, 1) area = width * height @@ -69,15 +86,12 @@ class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast): best_ratio_diff = ratio_diff best_ratio = ratio elif ratio_diff == best_ratio_diff: - if area > 0.5 * self.image_size * self.image_size * ratio[0] * ratio[1]: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: best_ratio = ratio return best_ratio - def dynamic_preprocess(self, image, max_num=36, min_num=1): - image_size = self.image_size - patch_size = self.patch_size - mask_size = image_size // patch_size - orig_width, orig_height = image.size + def dynamic_preprocess(self, image, image_size, patch_size, mask_size, max_num=36, min_num=1): + orig_height, orig_width = image.shape[-2:] w_crop_num = math.ceil(orig_width / float(image_size)) h_crop_num = math.ceil(orig_height / float(image_size)) @@ -95,7 +109,9 @@ class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast): target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) # find the closest aspect ratio to the target - target_aspect_ratio = self.find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height) + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) # calculate the target width and height target_width = image_size * target_aspect_ratio[0] @@ -148,113 +164,101 @@ class Phi4MultimodalImageProcessorFast(BaseImageProcessorFast): masks = torch.cat([masks, pad], dim=0) return masks + @auto_docstring def preprocess( self, images: ImageInput, + **kwargs: Unpack[Phi4MultimodalFastImageProcessorKwargs], + ) -> BatchFeature: + return super().preprocess(images, **kwargs) + + def _preprocess( + self, + images: List["torch.Tensor"], + size: SizeDict, + interpolation: Optional["F.InterpolationMode"], + patch_size: int, + dynamic_hd: int, + do_rescale: bool, + rescale_factor: Optional[float], + do_normalize: bool, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, ): - """ - Args: - images (`ImageInput`): - Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If - passing in images with pixel values between 0 and 1, set `do_rescale=False`. - image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): - Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. - image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): - Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image. - return_tensors (`str` or `TensorType`, *optional*): - The type of tensors to return. Can be one of: - - Unset: Return a list of `np.ndarray`. - - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. - - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. - - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. - - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. - """ - image_mean = image_mean if image_mean is not None else self.image_mean - image_std = image_std if image_std is not None else self.image_std - - images = make_flat_list_of_images(images) - if not valid_images(images): - raise ValueError( - "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " - "torch.Tensor, tf.Tensor or jax.ndarray." + if size.height != size.width: + raise ValueError("Phi4MultimodalFastImageProcessor only supports square sizes.") + mask_size = size.height // patch_size + images_transformed = [] + masks_transformed = [] + images_tokens = [] + image_sizes = [] + for image in images: + resized_image, attention_mask = self.dynamic_preprocess( + image, size.height, patch_size, mask_size, max_num=dynamic_hd ) - images = [convert_to_rgb(image) for image in images] - - image_size = self.image_size - patch_size = self.patch_size - mask_size = image_size // patch_size - imgs_and_masks = [self.dynamic_preprocess(image, max_num=self.dynamic_hd) for image in images] - images, image_attention_masks = [x[0] for x in imgs_and_masks], [x[1] for x in imgs_and_masks] - - images = [F.to_tensor(image) for image in images] - hd_images = [F.normalize(image, image_mean, image_std) for image in images] - global_image = [ - torch.nn.functional.interpolate( - image.unsqueeze(0).float(), - size=(image_size, image_size), - mode="bicubic", - ).to(image.dtype) - for image in hd_images - ] - - shapes = [[image.size(1), image.size(2)] for image in hd_images] - mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] - global_attention_mask = [torch.ones((1, mask_size, mask_size)) for _ in hd_images] - - hd_images_reshape = [] - for im, (h, w) in zip(hd_images, shapes): - im = im.reshape(1, 3, h // image_size, image_size, w // image_size, image_size) - im = im.permute(0, 2, 4, 1, 3, 5) - im = im.reshape(-1, 3, image_size, image_size) - hd_images_reshape.append(im.contiguous()) - - attention_masks_reshape = [] - for mask, (h, w) in zip(image_attention_masks, mask_shapes): - mask = mask.reshape(h // mask_size, mask_size, w // mask_size, mask_size) - mask = mask.transpose(1, 2) - mask = mask.reshape(-1, mask_size, mask_size) - attention_masks_reshape.append(mask.contiguous()) - - downsample_attention_masks = [] - for mask, (h, w) in zip(attention_masks_reshape, mask_shapes): - mask = mask[:, 0::2, 0::2] - mask = mask.reshape( - h // mask_size, w // mask_size, mask_size // 2 + mask_size % 2, mask_size // 2 + mask_size % 2 + processed_image = self.rescale_and_normalize( + resized_image, do_rescale, rescale_factor, do_normalize, image_mean, image_std ) - mask = mask.transpose(1, 2) - mask = mask.reshape(mask.size(0) * mask.size(1), mask.size(2) * mask.size(3)) - downsample_attention_masks.append(mask) + global_image = self.resize(processed_image, size, interpolation=interpolation, antialias=False) + height, width = processed_image.shape[-2:] + mask_height, mask_width = attention_mask.shape[-2:] + global_attention_mask = torch.ones((1, mask_size, mask_size)) - num_img_tokens = [ - 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks - ] + hd_image_reshape = processed_image.reshape( + 1, 3, height // size.height, size.height, width // size.width, size.width + ) + hd_image_reshape = hd_image_reshape.permute(0, 2, 4, 1, 3, 5) + hd_image_reshape = hd_image_reshape.reshape(-1, 3, size.height, size.width).contiguous() - hd_images_reshape = [ - torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape) - ] - hd_masks_reshape = [ - torch.cat([_global_mask] + [_mask], dim=0) - for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape) - ] - max_crops = max([img.size(0) for img in hd_images_reshape]) - image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] - image_transformed = torch.stack(image_transformed, dim=0) - mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] - mask_transformed = torch.stack(mask_transformed, dim=0) + attention_mask_reshape = attention_mask.reshape( + mask_height // mask_size, mask_size, mask_width // mask_size, mask_size + ) + attention_mask_reshape = attention_mask_reshape.transpose(1, 2) + attention_mask_reshape = attention_mask_reshape.reshape(-1, mask_size, mask_size).contiguous() - returned_input_image_embeds = image_transformed - returned_image_sizes = torch.tensor(shapes, dtype=torch.long) - returned_image_attention_mask = mask_transformed - returned_num_img_tokens = num_img_tokens + downsample_attention_mask = attention_mask_reshape[:, 0::2, 0::2] + downsample_attention_mask = downsample_attention_mask.reshape( + mask_height // mask_size, + mask_width // mask_size, + mask_size // 2 + mask_size % 2, + mask_size // 2 + mask_size % 2, + ) + downsample_attention_mask = downsample_attention_mask.transpose(1, 2) + downsample_attention_mask = downsample_attention_mask.reshape( + downsample_attention_mask.size(0) * downsample_attention_mask.size(1), + downsample_attention_mask.size(2) * downsample_attention_mask.size(3), + ) + + num_img_tokens = ( + 256 + + 1 + + int(downsample_attention_mask.sum().item()) + + int(downsample_attention_mask[:, 0].sum().item()) + + 16 + ) + + hd_image_reshape = torch.cat([global_image.unsqueeze(0), hd_image_reshape], dim=0) + hd_attention_mask_reshape = torch.cat([global_attention_mask, attention_mask_reshape], dim=0) + + images_transformed.append(hd_image_reshape) + masks_transformed.append(hd_attention_mask_reshape) + images_tokens.append(num_img_tokens) + image_sizes.append([height, width]) + max_crops = hd_image_reshape.size(0) + max_crops = max([img.size(0) for img in images_transformed]) + images_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in images_transformed] + images_transformed = torch.stack(images_transformed, dim=0) + masks_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in masks_transformed] + masks_transformed = torch.stack(masks_transformed, dim=0) + image_sizes = torch.tensor(image_sizes, dtype=torch.long) data = { - "image_pixel_values": returned_input_image_embeds, - "image_sizes": returned_image_sizes, - "image_attention_mask": returned_image_attention_mask, - "num_img_tokens": returned_num_img_tokens, + "image_pixel_values": images_transformed, + "image_sizes": image_sizes, + "image_attention_mask": masks_transformed, + "num_img_tokens": images_tokens, } return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/tests/models/phi4_multimodal/test_image_processing_phi4_multimodal.py b/tests/models/phi4_multimodal/test_image_processing_phi4_multimodal.py new file mode 100644 index 0000000000..3ad87b5780 --- /dev/null +++ b/tests/models/phi4_multimodal/test_image_processing_phi4_multimodal.py @@ -0,0 +1,307 @@ +# coding=utf-8 +# Copyright 2021 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import math +import unittest +import warnings + +import numpy as np +from packaging import version + +from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available + +from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs + + +if is_torch_available(): + import torch + +if is_vision_available(): + from PIL import Image + + if is_torchvision_available(): + from transformers import Phi4MultimodalImageProcessorFast + + +class Phi4MultimodalImageProcessingTester: + def __init__( + self, + parent, + batch_size=7, + num_channels=3, + image_size=100, + min_resolution=30, + max_resolution=400, + dynamic_hd=36, + do_resize=True, + size=None, + patch_size=14, + do_normalize=True, + image_mean=[0.5, 0.5, 0.5], + image_std=[0.5, 0.5, 0.5], + do_convert_rgb=True, + ): + super().__init__() + size = size if size is not None else {"height": 100, "width": 100} + self.parent = parent + self.batch_size = batch_size + self.num_channels = num_channels + self.image_size = image_size + self.min_resolution = min_resolution + self.max_resolution = max_resolution + self.dynamic_hd = dynamic_hd + self.do_resize = do_resize + self.size = size + self.patch_size = patch_size + self.do_normalize = do_normalize + self.image_mean = image_mean + self.image_std = image_std + self.do_convert_rgb = do_convert_rgb + + def prepare_image_processor_dict(self): + return { + "do_resize": self.do_resize, + "size": self.size, + "patch_size": self.patch_size, + "dynamic_hd": self.dynamic_hd, + "do_normalize": self.do_normalize, + "image_mean": self.image_mean, + "image_std": self.image_std, + "do_convert_rgb": self.do_convert_rgb, + } + + def expected_output_image_shape(self, images): + max_num_patches = 0 + for image in images: + if isinstance(image, Image.Image): + width, height = image.size + elif isinstance(image, np.ndarray): + height, width = image.shape[:2] + elif isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + w_crop_num = math.ceil(width / float(self.size["width"])) + h_crop_num = math.ceil(height / float(self.size["height"])) + num_patches = min(w_crop_num * h_crop_num + 1, self.dynamic_hd) + max_num_patches = max(max_num_patches, num_patches) + num_patches = max_num_patches + return num_patches, self.num_channels, self.size["height"], self.size["width"] + + def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False): + return prepare_image_inputs( + batch_size=self.batch_size, + num_channels=self.num_channels, + min_resolution=self.min_resolution, + max_resolution=self.max_resolution, + equal_resolution=equal_resolution, + numpify=numpify, + torchify=torchify, + ) + + +@require_torch +@require_vision +class Phi4MultimodalImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): + fast_image_processing_class = Phi4MultimodalImageProcessorFast if is_torchvision_available() else None + test_slow_image_processor = False + + def setUp(self): + super().setUp() + self.image_processor_tester = Phi4MultimodalImageProcessingTester(self) + + @property + def image_processor_dict(self): + return self.image_processor_tester.prepare_image_processor_dict() + + def test_image_processor_properties(self): + for image_processing_class in self.image_processor_list: + image_processing = image_processing_class(**self.image_processor_dict) + self.assertTrue(hasattr(image_processing, "do_resize")) + self.assertTrue(hasattr(image_processing, "size")) + self.assertTrue(hasattr(image_processing, "do_center_crop")) + self.assertTrue(hasattr(image_processing, "center_crop")) + self.assertTrue(hasattr(image_processing, "do_normalize")) + self.assertTrue(hasattr(image_processing, "image_mean")) + self.assertTrue(hasattr(image_processing, "image_std")) + self.assertTrue(hasattr(image_processing, "do_convert_rgb")) + + def test_image_processor_from_dict_with_kwargs(self): + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class.from_dict(self.image_processor_dict) + self.assertEqual(image_processor.size, {"height": 100, "width": 100}) + + image_processor = image_processing_class.from_dict(self.image_processor_dict, size=42) + self.assertEqual(image_processor.size, {"height": 42, "width": 42}) + + @unittest.skip(reason="Phi4MultimodalImageProcessorFast doesn't treat 4 channel PIL and numpy consistently yet") + def test_call_numpy_4_channels(self): + pass + + def test_cast_dtype_device(self): + for image_processing_class in self.image_processor_list: + if self.test_cast_dtype is not None: + # Initialize image_processor + image_processor = image_processing_class(**self.image_processor_dict) + + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + encoding = image_processor(image_inputs, return_tensors="pt") + # for layoutLM compatibility + self.assertEqual(encoding.image_pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.image_pixel_values.dtype, torch.float32) + + encoding = image_processor(image_inputs, return_tensors="pt").to(torch.float16) + self.assertEqual(encoding.image_pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.image_pixel_values.dtype, torch.float16) + + encoding = image_processor(image_inputs, return_tensors="pt").to("cpu", torch.bfloat16) + self.assertEqual(encoding.image_pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.image_pixel_values.dtype, torch.bfloat16) + + with self.assertRaises(TypeError): + _ = image_processor(image_inputs, return_tensors="pt").to(torch.bfloat16, "cpu") + + # Try with text + image feature + encoding = image_processor(image_inputs, return_tensors="pt") + encoding.update({"input_ids": torch.LongTensor([[1, 2, 3], [4, 5, 6]])}) + encoding = encoding.to(torch.float16) + + self.assertEqual(encoding.image_pixel_values.device, torch.device("cpu")) + self.assertEqual(encoding.image_pixel_values.dtype, torch.float16) + self.assertEqual(encoding.input_ids.dtype, torch.long) + + def test_call_pil(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PIL images + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False) + for image in image_inputs: + self.assertIsInstance(image, Image.Image) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").image_pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").image_pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_numpy(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random numpy tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True) + for image in image_inputs: + self.assertIsInstance(image, np.ndarray) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").image_pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + encoded_images = image_processing(image_inputs, return_tensors="pt").image_pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + self.assertEqual( + tuple(encoded_images.shape), (self.image_processor_tester.batch_size, *expected_output_image_shape) + ) + + def test_call_pytorch(self): + for image_processing_class in self.image_processor_list: + # Initialize image_processing + image_processing = image_processing_class(**self.image_processor_dict) + # create random PyTorch tensors + image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True) + + for image in image_inputs: + self.assertIsInstance(image, torch.Tensor) + + # Test not batched input + encoded_images = image_processing(image_inputs[0], return_tensors="pt").image_pixel_values + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]]) + self.assertEqual(tuple(encoded_images.shape), (1, *expected_output_image_shape)) + + # Test batched + expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs) + encoded_images = image_processing(image_inputs, return_tensors="pt").image_pixel_values + self.assertEqual( + tuple(encoded_images.shape), + (self.image_processor_tester.batch_size, *expected_output_image_shape), + ) + + def test_image_processor_preprocess_arguments(self): + is_tested = False + + for image_processing_class in self.image_processor_list: + image_processor = image_processing_class(**self.image_processor_dict) + + # validation done by _valid_processor_keys attribute + if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"): + preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args + preprocess_parameter_names.remove("self") + preprocess_parameter_names.sort() + valid_processor_keys = image_processor._valid_processor_keys + valid_processor_keys.sort() + self.assertEqual(preprocess_parameter_names, valid_processor_keys) + is_tested = True + + # validation done by @filter_out_non_signature_kwargs decorator + if hasattr(image_processor.preprocess, "_filter_out_non_signature_kwargs"): + if hasattr(self.image_processor_tester, "prepare_image_inputs"): + inputs = self.image_processor_tester.prepare_image_inputs() + elif hasattr(self.image_processor_tester, "prepare_video_inputs"): + inputs = self.image_processor_tester.prepare_video_inputs() + else: + self.skipTest(reason="No valid input preparation method found") + + with warnings.catch_warnings(record=True) as raised_warnings: + warnings.simplefilter("always") + image_processor(inputs, extra_argument=True) + + messages = " ".join([str(w.message) for w in raised_warnings]) + self.assertGreaterEqual(len(raised_warnings), 1) + self.assertIn("extra_argument", messages) + is_tested = True + + if not is_tested: + self.skipTest(reason="No validation found for `preprocess` method") + + @slow + def test_can_compile_fast_image_processor(self): + if self.fast_image_processing_class is None: + self.skipTest("Skipping compilation test as fast image processor is not defined") + if version.parse(torch.__version__) < version.parse("2.3"): + self.skipTest(reason="This test requires torch >= 2.3 to run.") + + torch.compiler.reset() + input_image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8) + image_processor = self.fast_image_processing_class(**self.image_processor_dict) + output_eager = image_processor(input_image, device=torch_device, return_tensors="pt") + + image_processor = torch.compile(image_processor, mode="reduce-overhead") + output_compiled = image_processor(input_image, device=torch_device, return_tensors="pt") + + torch.testing.assert_close( + output_eager.image_pixel_values, output_compiled.image_pixel_values, rtol=1e-4, atol=1e-4 + ) diff --git a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py index ee1c4c5b71..4dd51cc34e 100644 --- a/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py +++ b/tests/models/phi4_multimodal/test_modeling_phi4_multimodal.py @@ -31,12 +31,7 @@ from transformers import ( is_torch_available, is_vision_available, ) -from transformers.testing_utils import ( - require_soundfile, - require_torch, - slow, - torch_device, -) +from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device from transformers.utils import is_soundfile_available from ...generation.test_utils import GenerationTesterMixin @@ -285,6 +280,8 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase): audio_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav" def setUp(self): + # Currently, the Phi-4 checkpoint on the hub is not working with the latest Phi-4 code, so the slow integration tests + # won't pass without using the correct revision (refs/pr/70) self.processor = AutoProcessor.from_pretrained(self.checkpoint_path) self.generation_config = GenerationConfig(max_new_tokens=20, do_sample=False) self.user_token = "<|user|>" @@ -325,7 +322,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase): self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device ) - prompt = f"{self.user_token}<|image_1|>What is shown in this image?{self.end_token}{self.assistant_token}" + prompt = f"{self.user_token}<|image|>What is shown in this image?{self.end_token}{self.assistant_token}" inputs = self.processor(prompt, images=self.image, return_tensors="pt").to(torch_device) output = model.generate( @@ -349,7 +346,7 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase): for i in range(1, 5): url = f"https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-{i}-2048.jpg" images.append(Image.open(requests.get(url, stream=True).raw)) - placeholder += f"<|image_{i}|>" + placeholder += "<|image|>" prompt = f"{self.user_token}{placeholder}Summarize the deck of slides.{self.end_token}{self.assistant_token}" inputs = self.processor(prompt, images, return_tensors="pt").to(torch_device) @@ -371,8 +368,8 @@ class Phi4MultimodalIntegrationTest(unittest.TestCase): self.checkpoint_path, torch_dtype=torch.float16, device_map=torch_device ) - prompt = f"{self.user_token}<|audio_1|>What is happening in this audio?{self.end_token}{self.assistant_token}" - inputs = self.processor(prompt, audios=self.audio, sampling_rate=self.sampling_rate, return_tensors="pt").to( + prompt = f"{self.user_token}<|audio|>What is happening in this audio?{self.end_token}{self.assistant_token}" + inputs = self.processor(prompt, audio=self.audio, sampling_rate=self.sampling_rate, return_tensors="pt").to( torch_device ) diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py index b51125fdb6..f70adea169 100644 --- a/tests/test_image_processing_common.py +++ b/tests/test_image_processing_common.py @@ -279,7 +279,7 @@ class ImageProcessingTestMixin: saved_file = image_processor_first.save_pretrained(tmpdirname)[0] check_json_file_has_correct_format(saved_file) - use_fast = i == 1 + use_fast = i == 1 or not self.test_slow_image_processor image_processor_second = AutoImageProcessor.from_pretrained(tmpdirname, use_fast=use_fast) self.assertEqual(image_processor_second.to_dict(), image_processor_first.to_dict())