committed by
GitHub
parent
b0735dc0c1
commit
19fdb75cf0
@@ -30,7 +30,7 @@ from transformers.testing_utils import (
|
||||
require_torchvision,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.video_utils import make_batched_videos
|
||||
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -43,9 +43,9 @@ if is_vision_available():
|
||||
from transformers.video_utils import VideoMetadata, load_video
|
||||
|
||||
|
||||
def get_random_video(height, width, return_torch=False):
|
||||
def get_random_video(height, width, num_frames=8, return_torch=False):
|
||||
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
video = np.array(([random_frame] * 8))
|
||||
video = np.array(([random_frame] * num_frames))
|
||||
if return_torch:
|
||||
# move channel first
|
||||
return torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||
@@ -189,6 +189,53 @@ class BaseVideoProcessorTester(unittest.TestCase):
|
||||
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
|
||||
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
|
||||
|
||||
def test_group_and_reorder_videos(self):
|
||||
"""Tests that videos can be grouped by frame size and number of frames"""
|
||||
video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
|
||||
video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
|
||||
|
||||
# Group two videos of same size but different number of frames
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
|
||||
self.assertEqual(len(grouped_videos), 2)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 2)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group two videos of different size but same number of frames
|
||||
video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
|
||||
self.assertEqual(len(grouped_videos), 2)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 2)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group all three videos where some have same size or same frame count
|
||||
# But since none have frames and sizes identical, we'll have 3 groups
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
|
||||
self.assertEqual(len(grouped_videos), 3)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 3)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group if we had some videos with identical shapes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
|
||||
self.assertEqual(len(grouped_videos), 2)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 2)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
# Group if we had all videos with identical shapes
|
||||
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
|
||||
self.assertEqual(len(grouped_videos), 1)
|
||||
|
||||
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||
self.assertTrue(len(regrouped_videos), 1)
|
||||
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_av
|
||||
|
||||
Reference in New Issue
Block a user