Add option to resize like torchvision's Resize (#15419)
* Add torchvision's resize
* Rename torch_resize to default_to_square
* Apply suggestions from code review
* Add support for default_to_square and tuple of length 1
This commit is contained in:
@@ -184,7 +184,7 @@ class ImageFeatureExtractionMixin:
|
|||||||
else:
|
else:
|
||||||
return (image - mean) / std
|
return (image - mean) / std
|
||||||
|
|
||||||
def resize(self, image, size, resample=PIL.Image.BILINEAR):
|
def resize(self, image, size, resample=PIL.Image.BILINEAR, default_to_square=True, max_size=None):
|
||||||
"""
|
"""
|
||||||
Resizes `image`. Note that this will trigger a conversion of `image` to a PIL Image.
|
Resizes `image`. Note that this will trigger a conversion of `image` to a PIL Image.
|
||||||
|
|
||||||
@@ -192,19 +192,58 @@ class ImageFeatureExtractionMixin:
|
|||||||
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
|
||||||
The image to resize.
|
The image to resize.
|
||||||
size (`int` or `Tuple[int, int]`):
|
size (`int` or `Tuple[int, int]`):
|
||||||
The size to use for resizing the image.
|
The size to use for resizing the image. If `size` is a sequence like (w, h), output size will be
|
||||||
|
matched to this.
|
||||||
|
|
||||||
|
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
|
||||||
|
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
|
||||||
|
this number, i.e., if height > width, the image is rescaled to height = size * height / width and width = size.
|
||||||
resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
|
resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
|
||||||
The filter to use for resampling.
|
The filter to use for resampling.
|
||||||
|
default_to_square (`bool`, *optional*, defaults to `True`):
|
||||||
|
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
|
||||||
|
square (`size`,`size`). If set to `False`, will replicate
|
||||||
|
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
|
||||||
|
with support for resizing only the smallest edge and providing an optional `max_size`.
|
||||||
|
max_size (`int`, *optional*, defaults to `None`):
|
||||||
|
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
|
||||||
|
greater than `max_size` after being resized according to `size`, then the image is resized again so
|
||||||
|
that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e., the smaller
|
||||||
|
edge may be shorter than `size`. Only used if `default_to_square` is `False`.
|
||||||
"""
|
"""
|
||||||
self._ensure_format_supported(image)
|
self._ensure_format_supported(image)
|
||||||
|
|
||||||
if isinstance(size, int):
|
|
||||||
size = (size, size)
|
|
||||||
elif isinstance(size, list):
|
|
||||||
size = tuple(size)
|
|
||||||
if not isinstance(image, PIL.Image.Image):
|
if not isinstance(image, PIL.Image.Image):
|
||||||
image = self.to_pil_image(image)
|
image = self.to_pil_image(image)
|
||||||
|
|
||||||
|
if isinstance(size, list):
|
||||||
|
size = tuple(size)
|
||||||
|
|
||||||
|
if isinstance(size, int) or len(size) == 1:
|
||||||
|
if default_to_square:
|
||||||
|
size = (size, size) if isinstance(size, int) else (size[0], size[0])
|
||||||
|
else:
|
||||||
|
width, height = image.size
|
||||||
|
# specified size only for the smallest edge
|
||||||
|
short, long = (width, height) if width <= height else (height, width)
|
||||||
|
requested_new_short = size if isinstance(size, int) else size[0]
|
||||||
|
|
||||||
|
if short == requested_new_short:
|
||||||
|
return image
|
||||||
|
|
||||||
|
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
||||||
|
|
||||||
|
if max_size is not None:
|
||||||
|
if max_size <= requested_new_short:
|
||||||
|
raise ValueError(
|
||||||
|
f"max_size = {max_size} must be strictly greater than the requested "
|
||||||
|
f"size for the smaller edge size = {size}"
|
||||||
|
)
|
||||||
|
if new_long > max_size:
|
||||||
|
new_short, new_long = int(max_size * new_short / new_long), max_size
|
||||||
|
|
||||||
|
size = (new_short, new_long) if width <= height else (new_long, new_short)
|
||||||
|
|
||||||
return image.resize(size, resample=resample)
|
return image.resize(size, resample=resample)
|
||||||
|
|
||||||
def center_crop(self, image, size):
|
def center_crop(self, image, size):
|
||||||
|
|||||||
@@ -219,7 +219,7 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
|
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
|
||||||
self.assertEqual(resized_image1.size, (8, 16))
|
self.assertEqual(resized_image1.size, (8, 16))
|
||||||
|
|
||||||
# Passing and array converts it to a PIL Image.
|
# Passing an array converts it to a PIL Image.
|
||||||
resized_image2 = feature_extractor.resize(array, 8)
|
resized_image2 = feature_extractor.resize(array, 8)
|
||||||
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
|
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
|
||||||
self.assertEqual(resized_image2.size, (8, 8))
|
self.assertEqual(resized_image2.size, (8, 8))
|
||||||
@@ -230,6 +230,57 @@ class ImageFeatureExtractionTester(unittest.TestCase):
|
|||||||
self.assertEqual(resized_image3.size, (8, 16))
|
self.assertEqual(resized_image3.size, (8, 16))
|
||||||
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
|
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
|
||||||
|
|
||||||
|
def test_resize_image_and_array_non_default_to_square(self):
|
||||||
|
feature_extractor = ImageFeatureExtractionMixin()
|
||||||
|
|
||||||
|
heights_widths = [
|
||||||
|
# height, width
|
||||||
|
# square image
|
||||||
|
(28, 28),
|
||||||
|
(27, 27),
|
||||||
|
# rectangular image: h < w
|
||||||
|
(28, 34),
|
||||||
|
(29, 35),
|
||||||
|
# rectangular image: h > w
|
||||||
|
(34, 28),
|
||||||
|
(35, 29),
|
||||||
|
]
|
||||||
|
|
||||||
|
# single integer or single integer in tuple/list
|
||||||
|
sizes = [22, 27, 28, 36, [22], (27,)]
|
||||||
|
|
||||||
|
for (height, width), size in zip(heights_widths, sizes):
|
||||||
|
for max_size in (None, 37, 1000):
|
||||||
|
image = get_random_image(height, width)
|
||||||
|
array = np.array(image)
|
||||||
|
|
||||||
|
size = size[0] if isinstance(size, (list, tuple)) else size
|
||||||
|
# Size can be an int or a tuple of ints.
|
||||||
|
# If size is an int, smaller edge of the image will be matched to this number.
|
||||||
|
# i.e., if height > width, the image is rescaled to height = size * height / width and width = size.
|
||||||
|
if height < width:
|
||||||
|
exp_w, exp_h = (int(size * width / height), size)
|
||||||
|
if max_size is not None and max_size < exp_w:
|
||||||
|
exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
|
||||||
|
elif width < height:
|
||||||
|
exp_w, exp_h = (size, int(size * height / width))
|
||||||
|
if max_size is not None and max_size < exp_h:
|
||||||
|
exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
|
||||||
|
else:
|
||||||
|
exp_w, exp_h = (size, size)
|
||||||
|
if max_size is not None and max_size < size:
|
||||||
|
exp_w, exp_h = max_size, max_size
|
||||||
|
|
||||||
|
resized_image = feature_extractor.resize(image, size=size, default_to_square=False, max_size=max_size)
|
||||||
|
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
|
||||||
|
self.assertEqual(resized_image.size, (exp_w, exp_h))
|
||||||
|
|
||||||
|
# Passing an array converts it to a PIL Image.
|
||||||
|
resized_image2 = feature_extractor.resize(array, size=size, default_to_square=False, max_size=max_size)
|
||||||
|
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
|
||||||
|
self.assertEqual(resized_image2.size, (exp_w, exp_h))
|
||||||
|
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_resize_tensor(self):
|
def test_resize_tensor(self):
|
||||||
feature_extractor = ImageFeatureExtractionMixin()
|
feature_extractor = ImageFeatureExtractionMixin()
|
||||||
|
|||||||
Reference in New Issue
Block a user