Accept batched tensor of images as input to image processor (#21144)

* Accept a batched tensor of images as input

* Add to all image processors

* Update oneformer
This commit is contained in:
amyeroberts
2023-01-26 10:15:26 +00:00
committed by GitHub
parent 6f3faf3863
commit d18a1cba24
34 changed files with 220 additions and 105 deletions

View File

@@ -20,7 +20,7 @@ import numpy as np
import pytest
from transformers import is_torch_available, is_vision_available
from transformers.image_utils import ChannelDimension, get_channel_dimension_axis
from transformers.image_utils import ChannelDimension, get_channel_dimension_axis, make_list_of_images
from transformers.testing_utils import require_torch, require_vision
@@ -102,6 +102,58 @@ class ImageFeatureExtractionTester(unittest.TestCase):
self.assertEqual(array5.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array5, array1))
def test_make_list_of_images_numpy(self):
# Test a single image is converted to a list of 1 image
images = np.random.randint(0, 256, (16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0], images))
self.assertIsInstance(images_list, list)
# Test a batch of images is converted to a list of images
images = np.random.randint(0, 256, (4, 16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test a list of images is not modified
images = [np.random.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test batched masks with no channel dimension are converted to a list of masks
masks = np.random.randint(0, 2, (4, 16, 32))
masks_list = make_list_of_images(masks, expected_ndims=2)
self.assertEqual(len(masks_list), 4)
self.assertTrue(np.array_equal(masks_list[0], masks[0]))
self.assertIsInstance(masks_list, list)
@require_torch
def test_make_list_of_images_torch(self):
# Test a single image is converted to a list of 1 image
images = torch.randint(0, 256, (16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 1)
self.assertTrue(np.array_equal(images_list[0], images))
self.assertIsInstance(images_list, list)
# Test a batch of images is converted to a list of images
images = torch.randint(0, 256, (4, 16, 32, 3))
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
# Test a list of images is left unchanged
images = [torch.randint(0, 256, (16, 32, 3)) for _ in range(4)]
images_list = make_list_of_images(images)
self.assertEqual(len(images_list), 4)
self.assertTrue(np.array_equal(images_list[0], images[0]))
self.assertIsInstance(images_list, list)
@require_torch
def test_conversion_torch_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()