Pixtral: vectorize patch embeddings and enable tests (#35122)

* initial POC

* remove batch mix feature

* fix tests

* fix tests

* make style

* do not skip and instead fix tests

* update

* return back the test

* correct text with the correct ckpt
This commit is contained in:
Raushan Turganbay
2025-01-30 12:40:18 +01:00
committed by GitHub
parent 8bc4c89ee9
commit 9725e5be2f
10 changed files with 422 additions and 545 deletions

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import time
import unittest
@@ -92,49 +91,47 @@ class PixtralImageProcessingTester:
"do_convert_rgb": self.do_convert_rgb,
}
def expected_output_image_shape(self, images):
    """Return the expected ``(batch, channels, height, width)`` of the processed output.

    Args:
        images: A single image or a list/tuple of images. Each image may be a
            ``PIL.Image.Image``, a ``np.ndarray`` (height and width are read
            from the first two axes), or a ``torch.Tensor`` (height and width
            are read from the last two axes).

    Returns:
        A tuple ``(batch_size, self.num_channels, height, width)`` where
        ``height`` and ``width`` are the largest per-image values in the batch
        after resizing and rounding up to whole patches.

    Raises:
        TypeError: If an image is none of the supported types.
    """
    if not isinstance(images, (list, tuple)):
        images = [images]

    # These settings do not depend on the image, so read them once.
    max_height = max_width = self.size.get("longest_edge")
    patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]

    return_height, return_width = 0, 0
    for image in images:
        if isinstance(image, np.ndarray):
            height, width = image.shape[:2]
        elif isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            # PIL reports size as (width, height).
            width, height = image.size
        else:
            # Without this, height/width would be unbound (or silently reuse
            # the values of the previous image in the batch).
            raise TypeError(f"Unsupported image type: {type(image)}")

        # Shrink so the longer side fits within `longest_edge`, keeping the
        # aspect ratio; images that already fit are left untouched.
        ratio = max(height / max_height, width / max_width)
        if ratio > 1:
            height = int(np.ceil(height / ratio))
            width = int(np.ceil(width / ratio))

        # Round each side up to a whole number of patches.
        num_height_tokens = (height - 1) // patch_height + 1
        num_width_tokens = (width - 1) // patch_width + 1

        # The batch shape is driven by the largest image in each dimension.
        return_height = max(num_height_tokens * patch_height, return_height)
        return_width = max(num_width_tokens * patch_width, return_width)

    return len(images), self.num_channels, return_height, return_width
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
    """Build a flat batch of ``self.batch_size`` test images.

    The tester's channel count and min/max resolution, together with the
    ``equal_resolution``, ``numpify`` and ``torchify`` flags, are forwarded
    unchanged to the shared ``prepare_image_inputs`` helper.

    Returns:
        Whatever the shared helper returns for these arguments; the tests
        iterate it as a flat sequence of images (not a list of lists).
    """
    # NOTE: when this is defined as a method on the tester class, the bare
    # name below resolves to the module-level helper, not to this method.
    return prepare_image_inputs(
        batch_size=self.batch_size,
        num_channels=self.num_channels,
        min_resolution=self.min_resolution,
        max_resolution=self.max_resolution,
        equal_resolution=equal_resolution,
        numpify=numpify,
        torchify=torchify,
    )
@require_torch
@@ -173,23 +170,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
for image in image_inputs_list:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
image_inputs_list[0][0]
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
@@ -197,23 +189,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
for image in image_inputs_list:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
image_inputs_list[0][0]
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
@@ -221,23 +208,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
for image_inputs in image_inputs_list:
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
for image in image_inputs_list:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
image_inputs_list[0][0]
)
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
for encoded_image, image in zip(encoded_images, images):
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
@require_vision
@require_torch