From f95302584b21db464423e0151dd6ce99ca07ccc3 Mon Sep 17 00:00:00 2001 From: Richard Brown <135021519+rb-synth@users.noreply.github.com> Date: Thu, 2 May 2024 12:00:07 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Update=20image=5Fprocessing=5Fvi?= =?UTF-8?q?tmatte.py=20(#30566)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update image_processing_vitmatte.py * add test * [run-slow]vitmatte --- .../models/vitmatte/image_processing_vitmatte.py | 6 +++--- tests/models/vitmatte/test_image_processing_vitmatte.py | 4 ++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/vitmatte/image_processing_vitmatte.py b/src/transformers/models/vitmatte/image_processing_vitmatte.py index d7310bc0dd..6e4465e2db 100644 --- a/src/transformers/models/vitmatte/image_processing_vitmatte.py +++ b/src/transformers/models/vitmatte/image_processing_vitmatte.py @@ -133,9 +133,9 @@ class VitMatteImageProcessor(BaseImageProcessor): height, width = get_image_size(image, input_data_format) - if height % size_divisibility != 0 or width % size_divisibility != 0: - pad_height = size_divisibility - height % size_divisibility - pad_width = size_divisibility - width % size_divisibility + pad_height = 0 if height % size_divisibility == 0 else size_divisibility - height % size_divisibility + pad_width = 0 if width % size_divisibility == 0 else size_divisibility - width % size_divisibility + if pad_width + pad_height > 0: padding = ((0, pad_height), (0, pad_width)) image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format) diff --git a/tests/models/vitmatte/test_image_processing_vitmatte.py b/tests/models/vitmatte/test_image_processing_vitmatte.py index e1009c7592..e86cfde1e5 100644 --- a/tests/models/vitmatte/test_image_processing_vitmatte.py +++ b/tests/models/vitmatte/test_image_processing_vitmatte.py @@ -192,3 +192,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase): image = np.random.randn(3, 249, 491) images = image_processing.pad_image(image) assert images.shape == (3, 256, 512) + + image = np.random.randn(3, 249, 512) + images = image_processing.pad_image(image) + assert images.shape == (3, 256, 512)