Fix donut image processor (#20625)

* fix donut image processor

* Update test values

* Apply lower bound on resizing size

* Add in missing size param

* Resolve resize channel_dimension bug

* Update src/transformers/image_transforms.py
This commit is contained in:
amyeroberts
2022-12-08 19:10:40 +00:00
committed by GitHub
parent e3cc4487fe
commit cf1b8c34cc
4 changed files with 58 additions and 14 deletions

View File

@@ -48,7 +48,11 @@ if is_flax_available():
import jax.numpy as jnp import jax.numpy as jnp
def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray: def to_channel_dimension_format(
image: np.ndarray,
channel_dim: Union[ChannelDimension, str],
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
) -> np.ndarray:
""" """
Converts `image` to the channel dimension format specified by `channel_dim`. Converts `image` to the channel dimension format specified by `channel_dim`.
@@ -64,9 +68,11 @@ def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDim
if not isinstance(image, np.ndarray): if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}") raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
current_channel_dim = infer_channel_dimension_format(image) if input_channel_dim is None:
input_channel_dim = infer_channel_dimension_format(image)
target_channel_dim = ChannelDimension(channel_dim) target_channel_dim = ChannelDimension(channel_dim)
if current_channel_dim == target_channel_dim: if input_channel_dim == target_channel_dim:
return image return image
if target_channel_dim == ChannelDimension.FIRST: if target_channel_dim == ChannelDimension.FIRST:
@@ -152,6 +158,7 @@ def to_pil_image(
return PIL.Image.fromarray(image) return PIL.Image.fromarray(image)
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
def get_resize_output_image_size( def get_resize_output_image_size(
input_image: np.ndarray, input_image: np.ndarray,
size: Union[int, Tuple[int, int], List[int], Tuple[int]], size: Union[int, Tuple[int, int], List[int], Tuple[int]],
@@ -202,9 +209,6 @@ def get_resize_output_image_size(
short, long = (width, height) if width <= height else (height, width) short, long = (width, height) if width <= height else (height, width)
requested_new_short = size requested_new_short = size
if short == requested_new_short:
return (height, width)
new_short, new_long = requested_new_short, int(requested_new_short * long / short) new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None: if max_size is not None:
@@ -271,7 +275,10 @@ def resize(
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
# so we need to add it back if necessary. # so we need to add it back if necessary.
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
resized_image = to_channel_dimension_format(resized_image, data_format) # The image is always in channels last format after converting from a PIL image
resized_image = to_channel_dimension_format(
resized_image, data_format, input_channel_dim=ChannelDimension.LAST
)
return resized_image return resized_image

View File

@@ -210,7 +210,8 @@ class DonutImageProcessor(BaseImageProcessor):
**kwargs **kwargs
) -> np.ndarray: ) -> np.ndarray:
""" """
Resize the image to the specified size using thumbnail method. Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
corresponding dimension of the specified size.
Args: Args:
image (`np.ndarray`): image (`np.ndarray`):
@@ -222,8 +223,24 @@ class DonutImageProcessor(BaseImageProcessor):
data_format (`Optional[Union[str, ChannelDimension]]`, *optional*): data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
The data format of the output image. If unset, the same format as the input image is used. The data format of the output image. If unset, the same format as the input image is used.
""" """
output_size = (size["height"], size["width"]) input_height, input_width = get_image_size(image)
return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs) output_height, output_width = size["height"], size["width"]
# We always resize each dimension to the smaller of the input and output sizes.
height = min(input_height, output_height)
width = min(input_width, output_width)
if height == input_height and width == input_width:
return image
if input_height > input_width:
width = int(input_width * height / input_height)
elif input_width > input_height:
height = int(input_height * width / input_width)
return resize(
image, size=(height, width), resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs
)
def resize( def resize(
self, self,
@@ -250,7 +267,8 @@ class DonutImageProcessor(BaseImageProcessor):
size = get_size_dict(size) size = get_size_dict(size)
shortest_edge = min(size["height"], size["width"]) shortest_edge = min(size["height"], size["width"])
output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False) output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs) resized_image = resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
return resized_image
def rescale( def rescale(
self, self,
@@ -403,7 +421,7 @@ class DonutImageProcessor(BaseImageProcessor):
images = [to_numpy_array(image) for image in images] images = [to_numpy_array(image) for image in images]
if do_align_long_axis: if do_align_long_axis:
images = [self.align_long_axis(image) for image in images] images = [self.align_long_axis(image, size=size) for image in images]
if do_resize: if do_resize:
images = [self.resize(image=image, size=size, resample=resample) for image in images] images = [self.resize(image=image, size=size, resample=resample) for image in images]

View File

@@ -836,7 +836,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size([1, 1, 57532]) expected_shape = torch.Size([1, 1, 57532])
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device) expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device)
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
# step 2: generation # step 2: generation
@@ -872,7 +872,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
self.assertEqual(len(outputs.scores), 11) self.assertEqual(len(outputs.scores), 11)
self.assertTrue( self.assertTrue(
torch.allclose( torch.allclose(
outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4 outputs.scores[0][0, :3], torch.tensor([5.6019, -3.5070, 13.7123], device=torch_device), atol=1e-4
) )
) )

View File

@@ -184,6 +184,25 @@ class ImageTransformsTester(unittest.TestCase):
image = np.random.randint(0, 256, (3, 50, 40)) image = np.random.randint(0, 256, (3, 50, 40))
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17)) self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
# Test correct channel dimension is returned if the output height == 3
# Defaults to input format - channels first
image = np.random.randint(0, 256, (3, 18, 97))
resized_image = resize(image, (3, 20))
self.assertEqual(resized_image.shape, (3, 3, 20))
# Defaults to input format - channels last
image = np.random.randint(0, 256, (18, 97, 3))
resized_image = resize(image, (3, 20))
self.assertEqual(resized_image.shape, (3, 20, 3))
image = np.random.randint(0, 256, (3, 18, 97))
resized_image = resize(image, (3, 20), data_format="channels_last")
self.assertEqual(resized_image.shape, (3, 20, 3))
image = np.random.randint(0, 256, (18, 97, 3))
resized_image = resize(image, (3, 20), data_format="channels_first")
self.assertEqual(resized_image.shape, (3, 3, 20))
def test_resize(self): def test_resize(self):
image = np.random.randint(0, 256, (3, 224, 224)) image = np.random.randint(0, 256, (3, 224, 224))