Pixtral: vectorize patch embeddings and enable tests (#35122)
* initial POC
* - batch mix feature
* fix tests
* fix tests
* make style
* do not skip and instead fix tests
* update
* return back the test
* correct text with the correct ckpt
This commit is contained in:
committed by
GitHub
parent
8bc4c89ee9
commit
9725e5be2f
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
|
||||
@@ -92,49 +91,47 @@ class PixtralImageProcessingTester:
|
||||
"do_convert_rgb": self.do_convert_rgb,
|
||||
}
|
||||
|
||||
def expected_output_image_shape(self, image):
|
||||
if isinstance(image, Image.Image):
|
||||
width, height = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
height, width = image.shape[:2]
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height, width = image.shape[-2:]
|
||||
def expected_output_image_shape(self, images):
|
||||
if not isinstance(images, (list, tuple)):
|
||||
images = [images]
|
||||
|
||||
max_height = max_width = self.size.get("longest_edge")
|
||||
batch_size = len(images)
|
||||
return_height, return_width = 0, 0
|
||||
for image in images:
|
||||
if isinstance(image, Image.Image):
|
||||
width, height = image.size
|
||||
elif isinstance(image, np.ndarray):
|
||||
height, width = image.shape[:2]
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
ratio = max(height / max_height, width / max_width)
|
||||
if ratio > 1:
|
||||
height = int(np.ceil(height / ratio))
|
||||
width = int(np.ceil(width / ratio))
|
||||
max_height = max_width = self.size.get("longest_edge")
|
||||
|
||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||
num_height_tokens = (height - 1) // patch_height + 1
|
||||
num_width_tokens = (width - 1) // patch_width + 1
|
||||
ratio = max(height / max_height, width / max_width)
|
||||
if ratio > 1:
|
||||
height = int(np.ceil(height / ratio))
|
||||
width = int(np.ceil(width / ratio))
|
||||
|
||||
height = num_height_tokens * patch_height
|
||||
width = num_width_tokens * patch_width
|
||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||
num_height_tokens = (height - 1) // patch_height + 1
|
||||
num_width_tokens = (width - 1) // patch_width + 1
|
||||
|
||||
return self.num_channels, height, width
|
||||
return_height = max(num_height_tokens * patch_height, return_height)
|
||||
return_width = max(num_width_tokens * patch_width, return_width)
|
||||
|
||||
return batch_size, self.num_channels, return_height, return_width
|
||||
|
||||
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||
# Use prepare_image_inputs to make a list of list of single images
|
||||
|
||||
images_list = []
|
||||
for _ in range(self.batch_size):
|
||||
images = []
|
||||
for _ in range(random.randint(1, self.max_num_images_per_sample)):
|
||||
img = prepare_image_inputs(
|
||||
batch_size=1,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)[0]
|
||||
images.append(img)
|
||||
images_list.append(images)
|
||||
return images_list
|
||||
images = prepare_image_inputs(
|
||||
batch_size=self.batch_size,
|
||||
num_channels=self.num_channels,
|
||||
min_resolution=self.min_resolution,
|
||||
max_resolution=self.max_resolution,
|
||||
equal_resolution=equal_resolution,
|
||||
numpify=numpify,
|
||||
torchify=torchify,
|
||||
)
|
||||
return images
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -173,23 +170,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PIL images
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs()
|
||||
for image_inputs in image_inputs_list:
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
for image in image_inputs_list:
|
||||
self.assertIsInstance(image, Image.Image)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
image_inputs_list[0][0]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||
for encoded_image, image in zip(encoded_images, images):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_numpy(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
@@ -197,23 +189,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(numpify=True)
|
||||
for image_inputs in image_inputs_list:
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
for image in image_inputs_list:
|
||||
self.assertIsInstance(image, np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
image_inputs_list[0][0]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||
for encoded_image, image in zip(encoded_images, images):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
|
||||
self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
def test_call_pytorch(self):
|
||||
for image_processing_class in self.image_processor_list:
|
||||
@@ -221,23 +208,18 @@ class PixtralImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
image_processing = image_processing_class(**self.image_processor_dict)
|
||||
# create random PyTorch tensors
|
||||
image_inputs_list = self.image_processor_tester.prepare_image_inputs(torchify=True)
|
||||
for image_inputs in image_inputs_list:
|
||||
for image in image_inputs:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
for image in image_inputs_list:
|
||||
self.assertIsInstance(image, torch.Tensor)
|
||||
|
||||
# Test not batched input
|
||||
encoded_images = image_processing(image_inputs_list[0][0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(
|
||||
image_inputs_list[0][0]
|
||||
)
|
||||
self.assertEqual(tuple(encoded_images[0][0].shape), expected_output_image_shape)
|
||||
encoded_images = image_processing(image_inputs_list[0], return_tensors="pt").pixel_values
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list[0])
|
||||
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
# Test batched
|
||||
batch_encoded_images = image_processing(image_inputs_list, return_tensors="pt").pixel_values
|
||||
for encoded_images, images in zip(batch_encoded_images, image_inputs_list):
|
||||
for encoded_image, image in zip(encoded_images, images):
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image)
|
||||
self.assertEqual(tuple(encoded_image.shape), expected_output_image_shape)
|
||||
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs_list)
|
||||
self.assertEqual(tuple(batch_encoded_images.shape), expected_output_image_shape)
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user