Fix donut image processor (#20625)
* fix donut image processor * Update test values * Apply lower bound on resizing size * Add in missing size param * Resolve resize channel_dimension bug * Update src/transformers/image_transforms.py
This commit is contained in:
@@ -48,7 +48,11 @@ if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
|
||||
def to_channel_dimension_format(
|
||||
image: np.ndarray,
|
||||
channel_dim: Union[ChannelDimension, str],
|
||||
input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Converts `image` to the channel dimension format specified by `channel_dim`.
|
||||
|
||||
@@ -64,9 +68,11 @@ def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDim
|
||||
if not isinstance(image, np.ndarray):
|
||||
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")
|
||||
|
||||
current_channel_dim = infer_channel_dimension_format(image)
|
||||
if input_channel_dim is None:
|
||||
input_channel_dim = infer_channel_dimension_format(image)
|
||||
|
||||
target_channel_dim = ChannelDimension(channel_dim)
|
||||
if current_channel_dim == target_channel_dim:
|
||||
if input_channel_dim == target_channel_dim:
|
||||
return image
|
||||
|
||||
if target_channel_dim == ChannelDimension.FIRST:
|
||||
@@ -152,6 +158,7 @@ def to_pil_image(
|
||||
return PIL.Image.fromarray(image)
|
||||
|
||||
|
||||
# Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
|
||||
def get_resize_output_image_size(
|
||||
input_image: np.ndarray,
|
||||
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
|
||||
@@ -202,9 +209,6 @@ def get_resize_output_image_size(
|
||||
short, long = (width, height) if width <= height else (height, width)
|
||||
requested_new_short = size
|
||||
|
||||
if short == requested_new_short:
|
||||
return (height, width)
|
||||
|
||||
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
|
||||
|
||||
if max_size is not None:
|
||||
@@ -271,7 +275,10 @@ def resize(
|
||||
# If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
|
||||
# so we need to add it back if necessary.
|
||||
resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
|
||||
resized_image = to_channel_dimension_format(resized_image, data_format)
|
||||
# The image is always in channels last format after converting from a PIL image
|
||||
resized_image = to_channel_dimension_format(
|
||||
resized_image, data_format, input_channel_dim=ChannelDimension.LAST
|
||||
)
|
||||
return resized_image
|
||||
|
||||
|
||||
|
||||
@@ -210,7 +210,8 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
**kwargs
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Resize the image to the specified size using thumbnail method.
|
||||
Resize the image to make a thumbnail. The image is resized so that no dimension is larger than any
|
||||
corresponding dimension of the specified size.
|
||||
|
||||
Args:
|
||||
image (`np.ndarray`):
|
||||
@@ -222,8 +223,24 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
data_format (`Optional[Union[str, ChannelDimension]]`, *optional*):
|
||||
The data format of the output image. If unset, the same format as the input image is used.
|
||||
"""
|
||||
output_size = (size["height"], size["width"])
|
||||
return resize(image, size=output_size, resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs)
|
||||
input_height, input_width = get_image_size(image)
|
||||
output_height, output_width = size["height"], size["width"]
|
||||
|
||||
# We always resize to the smallest of either the input or output size.
|
||||
height = min(input_height, output_height)
|
||||
width = min(input_width, output_width)
|
||||
|
||||
if height == input_height and width == input_width:
|
||||
return image
|
||||
|
||||
if input_height > input_width:
|
||||
width = int(input_width * height / input_height)
|
||||
elif input_width > input_height:
|
||||
height = int(input_height * width / input_width)
|
||||
|
||||
return resize(
|
||||
image, size=(height, width), resample=resample, reducing_gap=2.0, data_format=data_format, **kwargs
|
||||
)
|
||||
|
||||
def resize(
|
||||
self,
|
||||
@@ -250,7 +267,8 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
size = get_size_dict(size)
|
||||
shortest_edge = min(size["height"], size["width"])
|
||||
output_size = get_resize_output_image_size(image, size=shortest_edge, default_to_square=False)
|
||||
return resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
resized_image = resize(image, size=output_size, resample=resample, data_format=data_format, **kwargs)
|
||||
return resized_image
|
||||
|
||||
def rescale(
|
||||
self,
|
||||
@@ -403,7 +421,7 @@ class DonutImageProcessor(BaseImageProcessor):
|
||||
images = [to_numpy_array(image) for image in images]
|
||||
|
||||
if do_align_long_axis:
|
||||
images = [self.align_long_axis(image) for image in images]
|
||||
images = [self.align_long_axis(image, size=size) for image in images]
|
||||
|
||||
if do_resize:
|
||||
images = [self.resize(image=image, size=size, resample=resample) for image in images]
|
||||
|
||||
@@ -836,7 +836,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
|
||||
expected_shape = torch.Size([1, 1, 57532])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device)
|
||||
expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
# step 2: generation
|
||||
@@ -872,7 +872,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(len(outputs.scores), 11)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4
|
||||
outputs.scores[0][0, :3], torch.tensor([5.6019, -3.5070, 13.7123], device=torch_device), atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -184,6 +184,25 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
image = np.random.randint(0, 256, (3, 50, 40))
|
||||
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
|
||||
|
||||
# Test correct channel dimension is returned if output size if height == 3
|
||||
# Defaults to input format - channels first
|
||||
image = np.random.randint(0, 256, (3, 18, 97))
|
||||
resized_image = resize(image, (3, 20))
|
||||
self.assertEqual(resized_image.shape, (3, 3, 20))
|
||||
|
||||
# Defaults to input format - channels last
|
||||
image = np.random.randint(0, 256, (18, 97, 3))
|
||||
resized_image = resize(image, (3, 20))
|
||||
self.assertEqual(resized_image.shape, (3, 20, 3))
|
||||
|
||||
image = np.random.randint(0, 256, (3, 18, 97))
|
||||
resized_image = resize(image, (3, 20), data_format="channels_last")
|
||||
self.assertEqual(resized_image.shape, (3, 20, 3))
|
||||
|
||||
image = np.random.randint(0, 256, (18, 97, 3))
|
||||
resized_image = resize(image, (3, 20), data_format="channels_first")
|
||||
self.assertEqual(resized_image.shape, (3, 3, 20))
|
||||
|
||||
def test_resize(self):
|
||||
image = np.random.randint(0, 256, (3, 224, 224))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user