committed by
GitHub
parent
b0735dc0c1
commit
19fdb75cf0
@@ -696,11 +696,13 @@ def group_videos_by_shape(
|
|||||||
grouped_videos_index = {}
|
grouped_videos_index = {}
|
||||||
for i, video in enumerate(videos):
|
for i, video in enumerate(videos):
|
||||||
shape = video.shape[-2::]
|
shape = video.shape[-2::]
|
||||||
|
num_frames = video.shape[-4] # video format BTCHW
|
||||||
|
shape = (num_frames, *shape)
|
||||||
if shape not in grouped_videos:
|
if shape not in grouped_videos:
|
||||||
grouped_videos[shape] = []
|
grouped_videos[shape] = []
|
||||||
grouped_videos[shape].append(video)
|
grouped_videos[shape].append(video)
|
||||||
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
|
grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
|
||||||
# stack videos with the same shape
|
# stack videos with the same size and number of frames
|
||||||
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
|
grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
|
||||||
return grouped_videos, grouped_videos_index
|
return grouped_videos, grouped_videos_index
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from transformers.testing_utils import (
|
|||||||
require_torchvision,
|
require_torchvision,
|
||||||
require_vision,
|
require_vision,
|
||||||
)
|
)
|
||||||
from transformers.video_utils import make_batched_videos
|
from transformers.video_utils import group_videos_by_shape, make_batched_videos, reorder_videos
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -43,9 +43,9 @@ if is_vision_available():
|
|||||||
from transformers.video_utils import VideoMetadata, load_video
|
from transformers.video_utils import VideoMetadata, load_video
|
||||||
|
|
||||||
|
|
||||||
def get_random_video(height, width, return_torch=False):
|
def get_random_video(height, width, num_frames=8, return_torch=False):
|
||||||
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
random_frame = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||||
video = np.array(([random_frame] * 8))
|
video = np.array(([random_frame] * num_frames))
|
||||||
if return_torch:
|
if return_torch:
|
||||||
# move channel first
|
# move channel first
|
||||||
return torch.from_numpy(video).permute(0, 3, 1, 2)
|
return torch.from_numpy(video).permute(0, 3, 1, 2)
|
||||||
@@ -189,6 +189,53 @@ class BaseVideoProcessorTester(unittest.TestCase):
|
|||||||
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
|
rgb_video = video_processor.convert_to_rgb(torch.cat([video, video[:, :1]], dim=1))
|
||||||
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
|
self.assertEqual(rgb_video.shape, (8, 3, 20, 20))
|
||||||
|
|
||||||
|
def test_group_and_reorder_videos(self):
|
||||||
|
"""Tests that videos can be grouped by frame size and number of frames"""
|
||||||
|
video_1 = get_random_video(20, 20, num_frames=3, return_torch=True)
|
||||||
|
video_2 = get_random_video(20, 20, num_frames=5, return_torch=True)
|
||||||
|
|
||||||
|
# Group two videos of same size but different number of frames
|
||||||
|
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2])
|
||||||
|
self.assertEqual(len(grouped_videos), 2)
|
||||||
|
|
||||||
|
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||||
|
self.assertTrue(len(regrouped_videos), 2)
|
||||||
|
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||||
|
|
||||||
|
# Group two videos of different size but same number of frames
|
||||||
|
video_3 = get_random_video(15, 20, num_frames=3, return_torch=True)
|
||||||
|
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_3])
|
||||||
|
self.assertEqual(len(grouped_videos), 2)
|
||||||
|
|
||||||
|
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||||
|
self.assertTrue(len(regrouped_videos), 2)
|
||||||
|
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||||
|
|
||||||
|
# Group all three videos where some have same size or same frame count
|
||||||
|
# But since none have frames and sizes identical, we'll have 3 groups
|
||||||
|
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_2, video_3])
|
||||||
|
self.assertEqual(len(grouped_videos), 3)
|
||||||
|
|
||||||
|
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||||
|
self.assertTrue(len(regrouped_videos), 3)
|
||||||
|
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||||
|
|
||||||
|
# Group if we had some videos with identical shapes
|
||||||
|
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_3])
|
||||||
|
self.assertEqual(len(grouped_videos), 2)
|
||||||
|
|
||||||
|
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||||
|
self.assertTrue(len(regrouped_videos), 2)
|
||||||
|
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||||
|
|
||||||
|
# Group if we had all videos with identical shapes
|
||||||
|
grouped_videos, grouped_videos_index = group_videos_by_shape([video_1, video_1, video_1])
|
||||||
|
self.assertEqual(len(grouped_videos), 1)
|
||||||
|
|
||||||
|
regrouped_videos = reorder_videos(grouped_videos, grouped_videos_index)
|
||||||
|
self.assertTrue(len(regrouped_videos), 1)
|
||||||
|
self.assertEqual(video_1.shape, regrouped_videos[0].shape)
|
||||||
|
|
||||||
|
|
||||||
@require_vision
|
@require_vision
|
||||||
@require_av
|
@require_av
|
||||||
|
|||||||
Reference in New Issue
Block a user