[video utils] group and reorder by number of frames (#38374)

fix
This commit is contained in:
Raushan Turganbay
2025-05-27 11:32:33 +02:00
committed by GitHub
parent b0735dc0c1
commit 19fdb75cf0
2 changed files with 53 additions and 4 deletions

View File

@@ -696,11 +696,13 @@ def group_videos_by_shape(
grouped_videos_index = {} grouped_videos_index = {}
for i, video in enumerate(videos): for i, video in enumerate(videos):
shape = video.shape[-2::] shape = video.shape[-2::]
num_frames = video.shape[-4] # video format BTCHW
shape = (num_frames, *shape)
if shape not in grouped_videos: if shape not in grouped_videos:
grouped_videos[shape] = [] grouped_videos[shape] = []
grouped_videos[shape].append(video) grouped_videos[shape].append(video)
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1) grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
# stack videos with the same shape # stack videos with the same size and number of frames
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()} grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
return grouped_videos, grouped_videos_index return grouped_videos, grouped_videos_index

View File

@@ -30,7 +30,7 @@ from transformers.testing_utils import (
require_torchvision, require_torchvision,
require_vision, require_vision,
) )
from transformers.video_utils import make_batched_videos from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
if is_torch_available(): if is_torch_available():
@@ -43,9 +43,9 @@ if is_vision_available():
from transformers.video_utils import VideoMetadata, load_video from transformers.video_utils import VideoMetadata, load_video
def get_random_video(height, width, return_torch=False): def get_random_video(height, width, num_frames=8, return_torch=False):
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
video = np.array(([random_frame] * 8)) video = np.array(([random_frame] * num_frames))
if return_torch: if return_torch:
# move channel first # move channel first
return torch.from_numpy(video).permute(0, 3, 1, 2) return torch.from_numpy(video).permute(0, 3, 1, 2)
@@ -189,6 +189,53 @@ class BaseVideoProcessorTester(unittest.TestCase):
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1)) rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
self.assertEqual(rgb_video.shape, (8, 3, 20, 20)) self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
def test_group_and_reorder_videos(self):
"""Tests that videos can be grouped by frame size and number of frames"""
video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
# Group two videos of same size but different number of frames
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group two videos of different size but same number of frames
video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group all three videos where some have same size or same frame count
# But since none have frames and sizes identical, we'll have 3 groups
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
self.assertEqual(len(grouped_videos), 3)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 3)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group if we had some videos with identical shapes
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
self.assertEqual(len(grouped_videos), 2)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 2)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
# Group if we had all videos with identical shapes
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
self.assertEqual(len(grouped_videos), 1)
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
self.assertTrue(len(regrouped_videos), 1)
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
@require_vision @require_vision
@require_av @require_av