Fix donut image processor (#20625)

* fix donut image processor

* Update test values

* Apply lower bound on resizing size

* Add in missing size param

* Resolve resize channel_dimension bug

* Update src/transformers/image_transforms.py
This commit is contained in:
amyeroberts
2022-12-08 19:10:40 +00:00
committed by GitHub
parent e3cc4487fe
commit cf1b8c34cc
4 changed files with 58 additions and 14 deletions

View File

@@ -836,7 +836,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size([1, 1, 57532])
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device)
expected_slice = torch.tensor([24.3873, -6.4491, 32.5394]).to(torch_device)
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
# step 2: generation
@@ -872,7 +872,7 @@ class DonutModelIntegrationTest(unittest.TestCase):
self.assertEqual(len(outputs.scores), 11)
self.assertTrue(
torch.allclose(
outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4
outputs.scores[0][0, :3], torch.tensor([5.6019, -3.5070, 13.7123], device=torch_device), atol=1e-4
)
)

View File

@@ -184,6 +184,25 @@ class ImageTransformsTester(unittest.TestCase):
image = np.random.randint(0, 256, (3, 50, 40))
self.assertEqual(get_resize_output_image_size(image, 20, default_to_square=False, max_size=22), (22, 17))
# Test correct channel dimension is returned if output size if height == 3
# Defaults to input format - channels first
image = np.random.randint(0, 256, (3, 18, 97))
resized_image = resize(image, (3, 20))
self.assertEqual(resized_image.shape, (3, 3, 20))
# Defaults to input format - channels last
image = np.random.randint(0, 256, (18, 97, 3))
resized_image = resize(image, (3, 20))
self.assertEqual(resized_image.shape, (3, 20, 3))
image = np.random.randint(0, 256, (3, 18, 97))
resized_image = resize(image, (3, 20), data_format="channels_last")
self.assertEqual(resized_image.shape, (3, 20, 3))
image = np.random.randint(0, 256, (18, 97, 3))
resized_image = resize(image, (3, 20), data_format="channels_first")
self.assertEqual(resized_image.shape, (3, 3, 20))
def test_resize(self):
image = np.random.randint(0, 256, (3, 224, 224))