From d64e4da7134911416067cddc8e040423852c3938 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 4 Jun 2024 14:20:03 +0500 Subject: [PATCH] Video-LLaVa: handle any number of frames (#31221) video-llava can handle more frames --- .../models/video_llava/modeling_video_llava.py | 11 +++++------ tests/models/video_llava/test_modeling_video_llava.py | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 19829d3347..b1c0931dcf 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -287,7 +287,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): num_images, num_image_patches, embed_dim = visual_features.shape batch_size, sequence_length = input_ids.shape left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id)) - special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index + special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index # 1. Create a mask to know where special image tokens are special_image_token_mask = input_ids == special_vision_token @@ -375,14 +375,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): # videos do not need to select features and it's always "full" (as it is done in the orig implementation) if pixel_values_videos is not None: batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape - if num_frames != 8: - raise ValueError(f"Video pixel values should have exactly `8` frames but foung `{num_frames}`") pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width) video_outputs = self.video_tower(pixel_values, output_hidden_states=True) video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1) else: video_outputs = None + num_frames = 0 if pixel_values_images is not None: image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) @@ -397,7 +396,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): else: image_outputs = None - return image_outputs, video_outputs + return image_outputs, video_outputs, num_frames @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -513,7 +512,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): # 2. Merge text and images if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1: - image_outputs, video_outputs = self._get_vision_features( + image_outputs, video_outputs, num_frames = self._get_vision_features( pixel_values_images=pixel_values_images, pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer, @@ -546,7 +545,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): input_ids, attention_mask, labels, - num_frames=8, + num_frames=num_frames, ) else: # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 1a91a2660f..ed2f67a494 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -487,6 +487,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset" ) video_file = np.load(video_file) + + # let's expand it for 16 frames, to check model can handle any number of frames + video_file = video_file.repeat(2, 0) inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16) # Make sure that `generate` works