Input data format (#25464)
* Add copied from statements for image processors * Move out rescale and normalize to base image processor * Remove rescale and normalize from vit (post rebase) * Update docstrings and tidy up * PR comments * Add input_data_format as preprocess argument * Resolve tests and tidy up * Remove num_channels argument * Update doc strings -> default ints not in code formatting
This commit is contained in:
@@ -165,6 +165,33 @@ class VideoMAEImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
|
||||
tuple(encoded_videos.shape), (self.image_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
|
||||
def test_call_numpy_4_channels(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
# create random numpy tensors
|
||||
self.image_processor_tester.num_channels = 4
|
||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=False, numpify=True)
|
||||
for video in video_inputs:
|
||||
self.assertIsInstance(video, list)
|
||||
self.assertIsInstance(video[0], np.ndarray)
|
||||
|
||||
# Test not batched input
|
||||
encoded_videos = image_processing(
|
||||
video_inputs[0], return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_video_shape = self.image_processor_tester.expected_output_image_shape([encoded_videos[0]])
|
||||
self.assertEqual(tuple(encoded_videos.shape), (1, *expected_output_video_shape))
|
||||
|
||||
# Test batched
|
||||
encoded_videos = image_processing(
|
||||
video_inputs, return_tensors="pt", image_mean=0, image_std=1, input_data_format="channels_first"
|
||||
).pixel_values
|
||||
expected_output_video_shape = self.image_processor_tester.expected_output_image_shape(encoded_videos)
|
||||
self.assertEqual(
|
||||
tuple(encoded_videos.shape), (self.image_processor_tester.batch_size, *expected_output_video_shape)
|
||||
)
|
||||
self.image_processor_tester.num_channels = 3
|
||||
|
||||
def test_call_pytorch(self):
|
||||
# Initialize image_processing
|
||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||
|
||||
Reference in New Issue
Block a user