Fix smart resize (#38706)

* Fix smart_resize bug

* Add smart_resize test

* Remove unnecessary error checking

* Fix smart_resize tests

---------

Co-authored-by: Richard Dong <rdong@rdong.c.groq-143208.internal>
This commit is contained in:
rdonggroq
2025-06-10 04:59:22 -04:00
committed by GitHub
parent 81799d8b55
commit afdb821318
3 changed files with 49 additions and 40 deletions

View File

@@ -81,9 +81,7 @@ def smart_resize(
3. The aspect ratio of the image is maintained as closely as possible. 3. The aspect ratio of the image is maintained as closely as possible.
""" """
if height < factor or width < factor: if max(height, width) / min(height, width) > 200:
raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
elif max(height, width) / min(height, width) > 200:
raise ValueError( raise ValueError(
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
) )
@@ -91,8 +89,8 @@ def smart_resize(
w_bar = round(width / factor) * factor w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels: if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels) beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = math.floor(width / beta / factor) * factor w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels: elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width)) beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor h_bar = math.ceil(height * beta / factor) * factor

View File

@@ -64,9 +64,7 @@ def smart_resize(
3. The aspect ratio of the image is maintained as closely as possible. 3. The aspect ratio of the image is maintained as closely as possible.
""" """
if height < factor or width < factor: if max(height, width) / min(height, width) > 200:
raise ValueError(f"height:{height} and width:{width} must be larger than factor:{factor}")
elif max(height, width) / min(height, width) > 200:
raise ValueError( raise ValueError(
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
) )
@@ -74,8 +72,8 @@ def smart_resize(
w_bar = round(width / factor) * factor w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels: if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels) beta = math.sqrt((height * width) / max_pixels)
h_bar = math.floor(height / beta / factor) * factor h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = math.floor(width / beta / factor) * factor w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels: elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width)) beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor h_bar = math.ceil(height * beta / factor) * factor

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools
import tempfile import tempfile
import unittest import unittest
@@ -169,18 +170,18 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertIsInstance(image[0], Image.Image) self.assertIsInstance(image[0], Image.Image)
# Test not batched input # Test not batched input
prcocess_out = image_processing(image_inputs[0], return_tensors="pt") process_out = image_processing(image_inputs[0], return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (4900, 1176) expected_output_image_shape = (4900, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) expected_image_grid_thws = torch.Tensor([[1, 70, 70]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) self.assertTrue((image_grid_thws == expected_image_grid_thws).all())
# Test batched # Test batched
prcocess_out = image_processing(image_inputs, return_tensors="pt") process_out = image_processing(image_inputs, return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (34300, 1176) expected_output_image_shape = (34300, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@@ -196,18 +197,18 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertIsInstance(image[0], np.ndarray) self.assertIsInstance(image[0], np.ndarray)
# Test not batched input # Test not batched input
prcocess_out = image_processing(image_inputs[0], return_tensors="pt") process_out = image_processing(image_inputs[0], return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (4900, 1176) expected_output_image_shape = (4900, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) expected_image_grid_thws = torch.Tensor([[1, 70, 70]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) self.assertTrue((image_grid_thws == expected_image_grid_thws).all())
# Test batched # Test batched
prcocess_out = image_processing(image_inputs, return_tensors="pt") process_out = image_processing(image_inputs, return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (34300, 1176) expected_output_image_shape = (34300, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@@ -224,18 +225,18 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
self.assertIsInstance(image[0], torch.Tensor) self.assertIsInstance(image[0], torch.Tensor)
# Test not batched input # Test not batched input
prcocess_out = image_processing(image_inputs[0], return_tensors="pt") process_out = image_processing(image_inputs[0], return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (4900, 1176) expected_output_image_shape = (4900, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]]) expected_image_grid_thws = torch.Tensor([[1, 70, 70]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
self.assertTrue((image_grid_thws == expected_image_grid_thws).all()) self.assertTrue((image_grid_thws == expected_image_grid_thws).all())
# Test batched # Test batched
prcocess_out = image_processing(image_inputs, return_tensors="pt") process_out = image_processing(image_inputs, return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (34300, 1176) expected_output_image_shape = (34300, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@@ -251,9 +252,9 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
# Test batched as a list of images # Test batched as a list of images
prcocess_out = image_processing(image_inputs, return_tensors="pt") process_out = image_processing(image_inputs, return_tensors="pt")
encoded_images = prcocess_out.pixel_values encoded_images = process_out.pixel_values
image_grid_thws = prcocess_out.image_grid_thw image_grid_thws = process_out.image_grid_thw
expected_output_image_shape = (34300, 1176) expected_output_image_shape = (34300, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@@ -261,9 +262,9 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
# Test batched as a nested list of images, where each sublist is one batch # Test batched as a nested list of images, where each sublist is one batch
image_inputs_nested = image_inputs[:3] + image_inputs[3:] image_inputs_nested = image_inputs[:3] + image_inputs[3:]
prcocess_out = image_processing(image_inputs_nested, return_tensors="pt") process_out = image_processing(image_inputs_nested, return_tensors="pt")
encoded_images_nested = prcocess_out.pixel_values encoded_images_nested = process_out.pixel_values
image_grid_thws_nested = prcocess_out.image_grid_thw image_grid_thws_nested = process_out.image_grid_thw
expected_output_image_shape = (34300, 1176) expected_output_image_shape = (34300, 1176)
expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7) expected_image_grid_thws = torch.Tensor([[1, 70, 70]] * 7)
self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape) self.assertEqual(tuple(encoded_images_nested.shape), expected_output_image_shape)
@@ -281,8 +282,8 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
for num_frames, expected_dims in expected_dims_by_frames.items(): for num_frames, expected_dims in expected_dims_by_frames.items():
image_processor_tester = Qwen2VLImageProcessingTester(self, num_frames=num_frames) image_processor_tester = Qwen2VLImageProcessingTester(self, num_frames=num_frames)
video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True) video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True)
prcocess_out = image_processing(None, videos=video_inputs, return_tensors="pt") process_out = image_processing(None, videos=video_inputs, return_tensors="pt")
encoded_video = prcocess_out.pixel_values_videos encoded_video = process_out.pixel_values_videos
expected_output_video_shape = (expected_dims, 1176) expected_output_video_shape = (expected_dims, 1176)
self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape) self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape)
@@ -293,8 +294,8 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
for patch_size in (1, 3, 5, 7): for patch_size in (1, 3, 5, 7):
image_processor_tester = Qwen2VLImageProcessingTester(self, patch_size=patch_size) image_processor_tester = Qwen2VLImageProcessingTester(self, patch_size=patch_size)
video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True) video_inputs = image_processor_tester.prepare_video_inputs(equal_resolution=True)
prcocess_out = image_processing(None, videos=video_inputs, return_tensors="pt") process_out = image_processing(None, videos=video_inputs, return_tensors="pt")
encoded_video = prcocess_out.pixel_values_videos encoded_video = process_out.pixel_values_videos
expected_output_video_shape = (171500, 1176) expected_output_video_shape = (171500, 1176)
self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape) self.assertEqual(tuple(encoded_video.shape), expected_output_video_shape)
@@ -308,9 +309,21 @@ class Qwen2VLImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
) )
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True) image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
prcocess_out = image_processor_loaded(image_inputs, return_tensors="pt") process_out = image_processor_loaded(image_inputs, return_tensors="pt")
expected_output_video_shape = [112, 1176] expected_output_video_shape = [112, 1176]
self.assertListEqual(list(prcocess_out.pixel_values.shape), expected_output_video_shape) self.assertListEqual(list(process_out.pixel_values.shape), expected_output_video_shape)
def test_custom_pixels(self):
    """Smoke-test every image processor with custom ``min_pixels``/``max_pixels`` bounds.

    Each distinct (min_pixels, max_pixels) pair drawn from a small set of
    pixel budgets is used to build a processor, which is then run on the
    tester's prepared image inputs. Nothing is asserted about the output;
    the test only checks that processing does not raise.
    """
    # Listed in ascending order so that every generated pair is already
    # ordered as (min_pixels, max_pixels).
    pixel_budgets = (100, 150, 200, 20000)
    # combinations_with_replacement yields each unordered pair exactly once,
    # avoiding the duplicate configurations a full cartesian product would
    # produce once the pair is normalised to (min, max).
    pixel_bounds = tuple(itertools.combinations_with_replacement(pixel_budgets, 2))

    for image_processing_class in self.image_processor_list:
        base_kwargs = self.image_processor_dict.copy()
        for min_pixels, max_pixels in pixel_bounds:
            # subTest records which class / bounds were active if a failure occurs.
            with self.subTest(
                image_processing_class=image_processing_class.__name__,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            ):
                processor_kwargs = {
                    **base_kwargs,
                    "min_pixels": min_pixels,
                    "max_pixels": max_pixels,
                }
                image_processor = image_processing_class(**processor_kwargs)
                image_inputs = self.image_processor_tester.prepare_image_inputs()
                # Just checking that it doesn't raise an error
                image_processor(image_inputs, return_tensors="pt")
def test_temporal_padding(self): def test_temporal_padding(self):
for image_processing_class in self.image_processor_list: for image_processing_class in self.image_processor_list: