Fix video batching to videollava (#32139)
--------- Co-authored-by: Merve Noyan <mervenoyan@Merve-MacBook-Pro.local>
This commit is contained in:
@@ -55,8 +55,11 @@ def make_batched_videos(videos) -> List[VideoInput]:
|
|||||||
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]) and len(videos[0].shape) == 4:
|
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
|
||||||
return [list(video) for video in videos]
|
if isinstance(videos[0], PIL.Image.Image):
|
||||||
|
return [videos]
|
||||||
|
elif len(videos[0].shape) == 4:
|
||||||
|
return [list(video) for video in videos]
|
||||||
|
|
||||||
elif is_valid_image(videos) and len(videos.shape) == 4:
|
elif is_valid_image(videos) and len(videos.shape) == 4:
|
||||||
return [list(videos)]
|
return [list(videos)]
|
||||||
|
|||||||
@@ -97,8 +97,7 @@ class VideoLlavaImageProcessingTester(unittest.TestCase):
|
|||||||
torchify=torchify,
|
torchify=torchify,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_video_inputs(self, equal_resolution=False, torchify=False):
|
def prepare_video_inputs(self, equal_resolution=False, numpify=False, torchify=False):
|
||||||
numpify = not torchify
|
|
||||||
images = prepare_image_inputs(
|
images = prepare_image_inputs(
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
num_channels=self.num_channels,
|
num_channels=self.num_channels,
|
||||||
@@ -108,15 +107,19 @@ class VideoLlavaImageProcessingTester(unittest.TestCase):
|
|||||||
numpify=numpify,
|
numpify=numpify,
|
||||||
torchify=torchify,
|
torchify=torchify,
|
||||||
)
|
)
|
||||||
|
|
||||||
# let's simply copy the frames to fake a long video-clip
|
# let's simply copy the frames to fake a long video-clip
|
||||||
videos = []
|
if numpify or torchify:
|
||||||
for image in images:
|
videos = []
|
||||||
if numpify:
|
for image in images:
|
||||||
video = image[None, ...].repeat(8, 0)
|
if numpify:
|
||||||
else:
|
video = image[None, ...].repeat(8, 0)
|
||||||
video = image[None, ...].repeat(8, 1, 1, 1)
|
else:
|
||||||
videos.append(video)
|
video = image[None, ...].repeat(8, 1, 1, 1)
|
||||||
|
videos.append(video)
|
||||||
|
else:
|
||||||
|
videos = []
|
||||||
|
for pil_image in images:
|
||||||
|
videos.append([pil_image] * 8)
|
||||||
|
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
@@ -197,7 +200,7 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
|||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
# create random numpy tensors
|
# create random numpy tensors
|
||||||
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
video_inputs = self.image_processor_tester.prepare_video_inputs(numpify=True, equal_resolution=True)
|
||||||
for video in video_inputs:
|
for video in video_inputs:
|
||||||
self.assertIsInstance(video, np.ndarray)
|
self.assertIsInstance(video, np.ndarray)
|
||||||
|
|
||||||
@@ -211,6 +214,24 @@ class VideoLlavaImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase)
|
|||||||
expected_output_video_shape = (5, 8, 3, 18, 18)
|
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||||
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||||
|
|
||||||
|
def test_call_pil_videos(self):
|
||||||
|
# Initialize image_processing
|
||||||
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
# the inputs come in list of lists batched format
|
||||||
|
video_inputs = self.image_processor_tester.prepare_video_inputs(equal_resolution=True)
|
||||||
|
for video in video_inputs:
|
||||||
|
self.assertIsInstance(video[0], Image.Image)
|
||||||
|
|
||||||
|
# Test not batched input
|
||||||
|
encoded_videos = image_processing(images=None, videos=video_inputs[0], return_tensors="pt").pixel_values_videos
|
||||||
|
expected_output_video_shape = (1, 8, 3, 18, 18)
|
||||||
|
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||||
|
|
||||||
|
# Test batched
|
||||||
|
encoded_videos = image_processing(images=None, videos=video_inputs, return_tensors="pt").pixel_values_videos
|
||||||
|
expected_output_video_shape = (5, 8, 3, 18, 18)
|
||||||
|
self.assertEqual(tuple(encoded_videos.shape), expected_output_video_shape)
|
||||||
|
|
||||||
def test_call_pytorch(self):
|
def test_call_pytorch(self):
|
||||||
# Initialize image_processing
|
# Initialize image_processing
|
||||||
image_processing = self.image_processing_class(**self.image_processor_dict)
|
image_processing = self.image_processing_class(**self.image_processor_dict)
|
||||||
|
|||||||
Reference in New Issue
Block a user