CLIPFeatureExtractor should resize images with kept aspect ratio (#11994)
* Resize with kept aspect ratio
* Fixed failed test
* Overload center_crop and resize methods instead
* resize should handle non-PIL images
* update slow test
* Tensor => tensor

Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
@@ -154,3 +154,56 @@ class CLIPFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
|
||||
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
def center_crop(self, image, size):
    """
    Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped to
    the size given, it will be padded (so the returned result has the size asked).

    Args:
        image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
            The image to crop. Non-PIL inputs are converted to a PIL image first.
        size (:obj:`int` or :obj:`Tuple[int, int]`):
            The size to which crop the image. An :obj:`int` requests a square crop; a pair (tuple or list) is read
            as ``(height, width)``.

    Returns:
        :obj:`PIL.Image.Image`: The cropped image.
    """
    self._ensure_format_supported(image)

    # Accept lists as well as tuples: sizes loaded from a JSON config arrive as lists, and wrapping a list into
    # ``(size, size)`` would break the arithmetic below.
    if isinstance(size, (tuple, list)):
        crop_height, crop_width = size
    else:
        crop_height = crop_width = size

    if not isinstance(image, Image.Image):
        image = self.to_pil_image(image)

    # PIL reports size as (width, height), whereas ``size`` above is (height, width).
    image_width, image_height = image.size

    crop_top = int((image_height - crop_height + 1) * 0.5)
    crop_left = int((image_width - crop_width + 1) * 0.5)

    # The box is (left, upper, right, lower).
    return image.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
|
||||
|
||||
def resize(self, image, size, resample=Image.BICUBIC):
    """
    Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image.

    Args:
        image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
            The image to resize.
        size (:obj:`int` or :obj:`Tuple[int, int]`):
            The size to use for resizing the image. If :obj:`int`, the shorter side of the image is resized to
            match it and the longer side is scaled to keep the aspect ratio. If a pair (tuple or list), it is used
            as the exact ``(width, height)`` of the output.
        resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BICUBIC`):
            The filter to use for resampling.

    Returns:
        :obj:`PIL.Image.Image`: The resized image. When ``size`` is an :obj:`int` already equal to the shorter
        side, the (PIL-converted) image is returned without resampling.
    """
    self._ensure_format_supported(image)

    if not isinstance(image, Image.Image):
        image = self.to_pil_image(image)

    # Accept lists as well as tuples: a list would otherwise fall into the integer branch and fail on
    # ``size * long / short``.
    if isinstance(size, (tuple, list)):
        new_w, new_h = size
    else:
        width, height = image.size
        short, long = (width, height) if width <= height else (height, width)
        if short == size:
            # Already at the requested scale; nothing to resample.
            return image
        # Scale the longer side by the same factor as the shorter one (truncated to an integer).
        new_short, new_long = size, int(size * long / short)
        new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)

    return image.resize((new_w, new_h), resample)
|
||||
|
||||
@@ -544,6 +544,7 @@ class CLIPModelIntegrationTest(unittest.TestCase):
|
||||
).to(torch_device)
|
||||
|
||||
# forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
|
||||
# verify the logits
|
||||
@@ -556,6 +557,6 @@ class CLIPModelIntegrationTest(unittest.TestCase):
|
||||
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
|
||||
)
|
||||
|
||||
expected_logits = torch.tensor([[24.5056, 18.8076]], device=torch_device)
|
||||
expected_logits = torch.tensor([[24.5701, 19.3049]], device=torch_device)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user