From 9d2cee8b48b503e2f71d73e1fb580c3de34c92ce Mon Sep 17 00:00:00 2001 From: Tobias Norlund Date: Thu, 10 Jun 2021 15:10:41 +0200 Subject: [PATCH] CLIPFeatureExtractor should resize images with kept aspect ratio (#11994) * Resize with kept aspect ratio * Fixed failed test * Overload center_crop and resize methods instead * resize should handle non-PIL images * update slow test * Tensor => tensor Co-authored-by: patil-suraj --- .../models/clip/feature_extraction_clip.py | 53 +++++++++++++++++++ tests/test_modeling_clip.py | 5 +- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/clip/feature_extraction_clip.py b/src/transformers/models/clip/feature_extraction_clip.py index d282526253..74a70918b7 100644 --- a/src/transformers/models/clip/feature_extraction_clip.py +++ b/src/transformers/models/clip/feature_extraction_clip.py @@ -154,3 +154,56 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs + + def center_crop(self, image, size): + """ + Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped to + the size is given, it will be padded (so the returned result has the size asked). + + Args: + image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`): + The image to resize. + size (:obj:`int` or :obj:`Tuple[int, int]`): + The size to which crop the image. + """ + self._ensure_format_supported(image) + if not isinstance(size, tuple): + size = (size, size) + + if not isinstance(image, Image.Image): + image = self.to_pil_image(image) + + image_width, image_height = image.size + crop_height, crop_width = size + + crop_top = int((image_height - crop_height + 1) * 0.5) + crop_left = int((image_width - crop_width + 1) * 0.5) + + return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height)) + + def resize(self, image, size, resample=Image.BICUBIC): + """ + Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image. + + Args: + image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`): + The image to resize. + size (:obj:`int` or :obj:`Tuple[int, int]`): + The size to use for resizing the image. If :obj:`int` it will be resized to match the shorter side + resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BILINEAR`): + The filter to user for resampling. + """ + self._ensure_format_supported(image) + + if not isinstance(image, Image.Image): + image = self.to_pil_image(image) + if isinstance(size, tuple): + new_w, new_h = size + else: + width, height = image.size + short, long = (width, height) if width <= height else (height, width) + if short == size: + return image + new_short, new_long = size, int(size * long / short) + new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short) + return image.resize((new_w, new_h), resample) diff --git a/tests/test_modeling_clip.py b/tests/test_modeling_clip.py index 8dc0ab214c..2a8f05d7a6 100644 --- a/tests/test_modeling_clip.py +++ b/tests/test_modeling_clip.py @@ -544,7 +544,8 @@ class CLIPModelIntegrationTest(unittest.TestCase): ).to(torch_device) # forward pass - outputs = model(**inputs) + with torch.no_grad(): + outputs = model(**inputs) # verify the logits self.assertEqual( @@ -556,6 +557,6 @@ class CLIPModelIntegrationTest(unittest.TestCase): torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])), ) - expected_logits = torch.tensor([[24.5056, 18.8076]], device=torch_device) + expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))