diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py index b493215ac7..cb02ed2874 100644 --- a/src/transformers/image_processing_utils_fast.py +++ b/src/transformers/image_processing_utils_fast.py @@ -49,6 +49,7 @@ from .utils import ( is_vision_available, logging, ) +from .utils.import_utils import is_rocm_platform if is_vision_available(): @@ -280,8 +281,34 @@ class BaseImageProcessorFast(BaseImageProcessor): "Size must contain 'height' and 'width' keys, or 'max_height' and 'max_width', or 'shortest_edge' key. Got" f" {size}." ) + # This is a workaround to avoid a bug in torch.compile when dealing with uint8 on AMD MI3XX GPUs + # Tracked in PyTorch issue: https://github.com/pytorch/pytorch/issues/155209 + # TODO: remove this once the bug is fixed (detected with torch==2.7.0+git1fee196, torchvision==0.22.0+9eb57cd) + if torch.compiler.is_compiling() and is_rocm_platform(): + return self.compile_friendly_resize(image, new_size, interpolation, antialias) return F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + @staticmethod + def compile_friendly_resize( + image: "torch.Tensor", + new_size: tuple[int, int], + interpolation: Optional["F.InterpolationMode"] = None, + antialias: bool = True, + ) -> "torch.Tensor": + """ + A wrapper around `F.resize` so that it is compatible with torch.compile when the image is a uint8 tensor. + """ + if image.dtype == torch.uint8: + image = image.float() / 256 + image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + image = image * 256 + image = torch.where(image > 255, 255, image) + image = torch.where(image < 0, 0, image) + image = image.round().to(torch.uint8) + else: + image = F.resize(image, new_size, interpolation=interpolation, antialias=antialias) + return image + def rescale( self, image: "torch.Tensor", diff --git a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py index 2eac0fe337..95ce3885ca 100644 --- a/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py +++ b/src/transformers/models/bridgetower/image_processing_bridgetower_fast.py @@ -165,13 +165,18 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast): raise ValueError(f"The `size` dictionary must contain the key `shortest_edge`. Got {size.keys()}") shorter = size.shortest_edge longer = int(1333 / 800 * shorter) - output_size = get_resize_output_image_size( + output_height, output_width = get_resize_output_image_size( image, shorter=shorter, longer=longer, size_divisor=size_divisor, ) - return F.resize(image, output_size, interpolation=interpolation, antialias=antialias) + return super().resize( + image=image, + size=SizeDict(height=output_height, width=output_width), + interpolation=interpolation, + antialias=antialias, + ) def center_crop( self, diff --git a/src/transformers/models/llava_next/image_processing_llava_next_fast.py b/src/transformers/models/llava_next/image_processing_llava_next_fast.py index 3356f514ed..2d09548592 100644 --- a/src/transformers/models/llava_next/image_processing_llava_next_fast.py +++ b/src/transformers/models/llava_next/image_processing_llava_next_fast.py @@ -137,7 +137,11 @@ class LlavaNextImageProcessorFast(BaseImageProcessorFast): new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) # Resize the image - resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + resized_image = self.resize( + image=image, + size=SizeDict(height=new_height, width=new_width), + interpolation=interpolation, + ) return resized_image diff --git a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py index 6eba44938c..9a727a62b3 100644 --- a/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +++ b/src/transformers/models/llava_onevision/image_processing_llava_onevision_fast.py @@ -142,7 +142,11 @@ class LlavaOnevisionImageProcessorFast(BaseImageProcessorFast): new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format) # Resize the image - resized_image = F.resize(image, (new_height, new_width), interpolation=interpolation) + resized_image = self.resize( + image=image, + size=SizeDict(height=new_height, width=new_width), + interpolation=interpolation, + ) return resized_image diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py index 762ed117df..2c947e758f 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py @@ -203,8 +203,10 @@ class Qwen2VLImageProcessorFast(BaseImageProcessorFast): min_pixels=size["shortest_edge"], max_pixels=size["longest_edge"], ) - stacked_images = F.resize( - stacked_images, size=(resized_height, resized_width), interpolation=interpolation + stacked_images = self.resize( + image=stacked_images, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, ) resized_images_grouped[shape] = stacked_images resized_images = reorder_images(resized_images_grouped, grouped_images_index) diff --git a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py index 5640b8d333..6eac7efedf 100644 --- a/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/video_processing_qwen2_vl.py @@ -250,8 +250,10 @@ class Qwen2VLVideoProcessor(BaseVideoProcessor): min_pixels=min_pixels, max_pixels=max_pixels, ) - stacked_videos = F.resize( - stacked_videos, size=(resized_height, resized_width), interpolation=interpolation + stacked_videos = self.resize( + image=stacked_videos, + size=SizeDict(height=resized_height, width=resized_width), + interpolation=interpolation, ) resized_videos_grouped[shape] = stacked_videos resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index) diff --git a/tests/models/internvl/test_modeling_internvl.py b/tests/models/internvl/test_modeling_internvl.py index 963e840e0b..d7e1132be6 100644 --- a/tests/models/internvl/test_modeling_internvl.py +++ b/tests/models/internvl/test_modeling_internvl.py @@ -705,6 +705,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase): ("xpu", 3): torch.tensor([-9.8750, -0.5703, 1.4297, -10.3125, -10.3125], dtype=torch.float16), ("cuda", 7): torch.tensor([-9.8750, -0.4861, 1.4648, -10.3359, -10.3359], dtype=torch.float16), ("cuda", 8): torch.tensor([-9.8906, -0.4995, 1.4473, -10.3359, -10.3438], dtype=torch.float16), + ("rocm", (9, 5)): torch.tensor([ -9.8906, -0.4976, 1.4502, -10.3359, -10.3438], dtype=torch.float16), } ) # fmt: skip expected_logits = torch.tensor(expected_logits_all.get_expectation(), dtype=torch.float16)