CLIPFeatureExtractor should resize images with kept aspect ratio (#11994)

* Resize with kept aspect ratio

* Fixed failed test

* Overload center_crop and resize methods instead

* resize should handle non-PIL images

* update slow test

* Tensor => tensor

Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
Tobias Norlund
2021-06-10 15:10:41 +02:00
committed by GitHub
parent 472a867626
commit 9d2cee8b48
2 changed files with 56 additions and 2 deletions

View File

@@ -154,3 +154,56 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs return encoded_inputs
def center_crop(self, image, size):
"""
Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped to
the size is given, it will be padded (so the returned result has the size asked).
Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to resize.
size (:obj:`int` or :obj:`Tuple[int, int]`):
The size to which crop the image.
"""
self._ensure_format_supported(image)
if not isinstance(size, tuple):
size = (size, size)
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
image_width, image_height = image.size
crop_height, crop_width = size
crop_top = int((image_height - crop_height + 1) * 0.5)
crop_left = int((image_width - crop_width + 1) * 0.5)
return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
def resize(self, image, size, resample=Image.BICUBIC):
"""
Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image.
Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to resize.
size (:obj:`int` or :obj:`Tuple[int, int]`):
The size to use for resizing the image. If :obj:`int` it will be resized to match the shorter side
resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BILINEAR`):
The filter to user for resampling.
"""
self._ensure_format_supported(image)
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
if isinstance(size, tuple):
new_w, new_h = size
else:
width, height = image.size
short, long = (width, height) if width <= height else (height, width)
if short == size:
return image
new_short, new_long = size, int(size * long / short)
new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
return image.resize((new_w, new_h), resample)

View File

@@ -544,6 +544,7 @@ class CLIPModelIntegrationTest(unittest.TestCase):
).to(torch_device) ).to(torch_device)
# forward pass # forward pass
with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
# verify the logits # verify the logits
@@ -556,6 +557,6 @@ class CLIPModelIntegrationTest(unittest.TestCase):
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])), torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
) )
expected_logits = torch.tensor([[24.5056, 18.8076]], device=torch_device) expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)) self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))