🚨 Update image_processing_vitmatte.py (#30566)
* Update image_processing_vitmatte.py * add test * [run-slow]vitmatte
This commit is contained in:
@@ -133,9 +133,9 @@ class VitMatteImageProcessor(BaseImageProcessor):
|
|||||||
|
|
||||||
height, width = get_image_size(image, input_data_format)
|
height, width = get_image_size(image, input_data_format)
|
||||||
|
|
||||||
if height % size_divisibility != 0 or width % size_divisibility != 0:
|
pad_height = 0 if height % size_divisibility == 0 else size_divisibility - height % size_divisibility
|
||||||
pad_height = size_divisibility - height % size_divisibility
|
pad_width = 0 if width % size_divisibility == 0 else size_divisibility - width % size_divisibility
|
||||||
pad_width = size_divisibility - width % size_divisibility
|
if pad_width + pad_height > 0:
|
||||||
padding = ((0, pad_height), (0, pad_width))
|
padding = ((0, pad_height), (0, pad_width))
|
||||||
image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format)
|
image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format)
|
||||||
|
|
||||||
|
|||||||
@@ -192,3 +192,7 @@ class VitMatteImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
|||||||
image = np.random.randn(3, 249, 491)
|
image = np.random.randn(3, 249, 491)
|
||||||
images = image_processing.pad_image(image)
|
images = image_processing.pad_image(image)
|
||||||
assert images.shape == (3, 256, 512)
|
assert images.shape == (3, 256, 512)
|
||||||
|
|
||||||
|
image = np.random.randn(3, 249, 512)
|
||||||
|
images = image_processing.pad_image(image)
|
||||||
|
assert images.shape == (3, 256, 512)
|
||||||
|
|||||||
Reference in New Issue
Block a user