Video-LLaVa: handle any number of frames (#31221)
video-llava can handle more frames
This commit is contained in:
committed by
GitHub
parent
36ade4a32b
commit
d64e4da713
@@ -287,7 +287,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
|||||||
num_images, num_image_patches, embed_dim = visual_features.shape
|
num_images, num_image_patches, embed_dim = visual_features.shape
|
||||||
batch_size, sequence_length = input_ids.shape
|
batch_size, sequence_length = input_ids.shape
|
||||||
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
|
||||||
special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index
|
special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index
|
||||||
|
|
||||||
# 1. Create a mask to know where special image tokens are
|
# 1. Create a mask to know where special image tokens are
|
||||||
special_image_token_mask = input_ids == special_vision_token
|
special_image_token_mask = input_ids == special_vision_token
|
||||||
@@ -375,14 +375,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
|||||||
# videos do not need to select features and it's always "full" (as it is done in the orig implementation)
|
# videos do not need to select features and it's always "full" (as it is done in the orig implementation)
|
||||||
if pixel_values_videos is not None:
|
if pixel_values_videos is not None:
|
||||||
batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
|
batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
|
||||||
if num_frames != 8:
|
|
||||||
raise ValueError(f"Video pixel values should have exactly `8` frames but foung `{num_frames}`")
|
|
||||||
|
|
||||||
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
|
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
|
||||||
video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
|
video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
|
||||||
video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
|
video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
|
||||||
else:
|
else:
|
||||||
video_outputs = None
|
video_outputs = None
|
||||||
|
num_frames = 0
|
||||||
|
|
||||||
if pixel_values_images is not None:
|
if pixel_values_images is not None:
|
||||||
image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
|
image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
|
||||||
@@ -397,7 +396,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
image_outputs = None
|
image_outputs = None
|
||||||
|
|
||||||
return image_outputs, video_outputs
|
return image_outputs, video_outputs, num_frames
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
||||||
@@ -513,7 +512,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
|||||||
|
|
||||||
# 2. Merge text and images
|
# 2. Merge text and images
|
||||||
if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
|
if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
|
||||||
image_outputs, video_outputs = self._get_vision_features(
|
image_outputs, video_outputs, num_frames = self._get_vision_features(
|
||||||
pixel_values_images=pixel_values_images,
|
pixel_values_images=pixel_values_images,
|
||||||
pixel_values_videos=pixel_values_videos,
|
pixel_values_videos=pixel_values_videos,
|
||||||
vision_feature_layer=vision_feature_layer,
|
vision_feature_layer=vision_feature_layer,
|
||||||
@@ -546,7 +545,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
|
|||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
labels,
|
labels,
|
||||||
num_frames=8,
|
num_frames=num_frames,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
|
# In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
|
||||||
|
|||||||
@@ -487,6 +487,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
|
|||||||
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
|
||||||
)
|
)
|
||||||
video_file = np.load(video_file)
|
video_file = np.load(video_file)
|
||||||
|
|
||||||
|
# let's expand it for 16 frames, to check model can handle any number of frames
|
||||||
|
video_file = video_file.repeat(2, 0)
|
||||||
inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
|
inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)
|
||||||
|
|
||||||
# Make sure that `generate` works
|
# Make sure that `generate` works
|
||||||
|
|||||||
Reference in New Issue
Block a user