Image transforms add center crop (#19718)
* Add center crop to transforms library * Return PIL images if PIL image input by default * Fixup and add docstring * Trigger CI * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/image_transforms.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * PR comments - move comments; unindent Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -35,6 +35,7 @@ if is_vision_available():
|
||||
import PIL.Image
|
||||
|
||||
from transformers.image_transforms import (
|
||||
center_crop,
|
||||
get_resize_output_image_size,
|
||||
normalize,
|
||||
resize,
|
||||
@@ -195,3 +196,26 @@ class ImageTransformsTester(unittest.TestCase):
|
||||
self.assertIsInstance(normalized_image, np.ndarray)
|
||||
self.assertEqual(normalized_image.shape, (3, 224, 224))
|
||||
self.assertTrue(np.allclose(normalized_image, expected_image))
|
||||
|
||||
def test_center_crop(self):
|
||||
image = np.random.randint(0, 256, (3, 224, 224))
|
||||
|
||||
# Test that exception is raised if inputs are incorrect
|
||||
with self.assertRaises(ValueError):
|
||||
center_crop(image, 10)
|
||||
|
||||
# Test result is correct - output data format is channels_first and center crop
|
||||
# correctly computed
|
||||
expected_image = image[:, 52:172, 82:142].transpose(1, 2, 0)
|
||||
cropped_image = center_crop(image, (120, 60), data_format="channels_last")
|
||||
self.assertIsInstance(cropped_image, np.ndarray)
|
||||
self.assertEqual(cropped_image.shape, (120, 60, 3))
|
||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||
|
||||
# Test that image is padded with zeros if crop size is larger than image size
|
||||
expected_image = np.zeros((300, 260, 3))
|
||||
expected_image[38:262, 18:242, :] = image.transpose((1, 2, 0))
|
||||
cropped_image = center_crop(image, (300, 260), data_format="channels_last")
|
||||
self.assertIsInstance(cropped_image, np.ndarray)
|
||||
self.assertEqual(cropped_image.shape, (300, 260, 3))
|
||||
self.assertTrue(np.allclose(cropped_image, expected_image))
|
||||
|
||||
Reference in New Issue
Block a user