From e97c76000656649559a78dffce678ea42f3fcb7f Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Thu, 27 Mar 2025 14:46:11 +0100 Subject: [PATCH] [chat templates} support loading audio from video (#36955) * add audio from video * typos * delete print * comments --- src/transformers/processing_utils.py | 118 +++++++++++++++------------ tests/test_processing_common.py | 69 +++++++++++++++- 2 files changed, 129 insertions(+), 58 deletions(-) diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 7551714773..2b2d158179 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -23,7 +23,7 @@ import sys import typing import warnings from pathlib import Path -from typing import Any, Callable, Optional, TypedDict, Union +from typing import Any, Callable, Dict, List, Optional, TypedDict, Union import numpy as np import typing_extensions @@ -386,14 +386,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): return_assistant_tokens_mask: Optional[bool] = False -class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False): +class ChatTemplateLoadKwargs(TypedDict, total=False): """ - Keyword arguments for processor chat templates. + Keyword arguments used to load multimodal data in processor chat templates. - tokenize (`bool`, *optional*, defaults to `False`): - Whether to tokenize the output or not. - return_dict (`bool`, defaults to `False`): - Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. num_frames (`int`, *optional*): Number of frames to sample uniformly. If not passed, the whole video is loaded. video_load_backend (`str`, *optional*, defaults to `"pyav"`): @@ -415,13 +411,26 @@ class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False): return np.linspace(start_idx, end_idx, num_frames, dtype=int) """ - tokenize: Optional[bool] = False - return_dict: Optional[bool] = False num_frames: Optional[int] = None video_load_backend: Optional[str] = "pyav" video_fps: Optional[int] = None sampling_rate: Optional[int] = 16_000 sample_indices_fn: Optional[Callable] = None + load_audio_from_video: Optional[bool] = False + + +class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False): + """ + Keyword arguments for processor's `apply_chat_template`. + + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + return_dict (`bool`, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. + """ + + tokenize: Optional[bool] = False + return_dict: Optional[bool] = False class AllKwargsForChatTemplate( @@ -1236,11 +1245,11 @@ class ProcessorMixin(PushToHubMixin): def _process_messages_for_chat_template( self, - conversation: list[list[dict[str, str]]], - batch_images: list[ImageInput], - batch_videos: list[VideoInput], - batch_video_metadata: list[list[dict[str, any]]], - **chat_template_kwargs: Unpack[AllKwargsForChatTemplate], + conversation: List[List[Dict[str, str]]], + batch_images: List[ImageInput], + batch_videos: List[VideoInput], + batch_video_metadata: List[List[Dict[str, any]]], + **mm_load_kwargs: Unpack[ChatTemplateLoadKwargs], ): """ Used within `apply_chat_template` when a model has a special way to process conversation history. For example, @@ -1311,18 +1320,18 @@ class ProcessorMixin(PushToHubMixin): ) # Fill two sets of kwargs that should be used by tokenizer's `apply_chat_template` - # and for multimodal chat template + # and for multimodal data loading. Everything else will be used in `__call__` tokenizer_template_kwargs = {} for tokenizer_key in TokenizerChatTemplateKwargs.__annotations__.keys(): - tokenizer_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None) - value = kwargs.pop(tokenizer_key, tokenizer_value) + default_value = getattr(TokenizerChatTemplateKwargs, tokenizer_key, None) + value = kwargs.pop(tokenizer_key, default_value) tokenizer_template_kwargs[tokenizer_key] = value - chat_template_kwargs = {} - for key in ProcessorChatTemplateKwargs.__annotations__.keys(): - processor_value = getattr(ProcessorChatTemplateKwargs, key, None) - value = kwargs.pop(key, processor_value) - chat_template_kwargs[key] = value + mm_load_kwargs = {} + for mm_load_key in ChatTemplateLoadKwargs.__annotations__.keys(): + default_value = getattr(ChatTemplateLoadKwargs, mm_load_key, None) + value = kwargs.pop(mm_load_key, default_value) + mm_load_kwargs[mm_load_key] = value if isinstance(conversation, (list, tuple)) and ( isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") @@ -1333,13 +1342,8 @@ class ProcessorMixin(PushToHubMixin): is_batched = False conversations = [conversation] - num_frames = chat_template_kwargs.get("num_frames") - video_fps = chat_template_kwargs.get("video_fps") - video_load_backend = chat_template_kwargs.get("video_load_backend") - tokenize = chat_template_kwargs.get("tokenize") - return_dict = chat_template_kwargs.get("return_dict") - sample_indices_fn = chat_template_kwargs.get("sample_indices_fn") - sampling_rate = chat_template_kwargs.pop("sampling_rate") + tokenize = kwargs.pop("tokenize", False) + return_dict = kwargs.pop("return_dict", False) if tokenize: batch_images, batch_videos = [], [] @@ -1369,31 +1373,37 @@ class ProcessorMixin(PushToHubMixin): if key in vision_info and vision_info["type"] == "video" ] - # Audio models do not accept nested list of audios (yet!) - for fname in audio_fnames: - batch_audios.append(load_audio(fname, sampling_rate=sampling_rate)) for fname in image_fnames: images.append(load_image(fname)) - for fname in video_fnames: - if isinstance(fname, (list, tuple)) and isinstance(fname[0], str): - video = [np.array(load_image(image_fname)).T for image_fname in fname] - # create a 4D video because `load_video` always returns a 4D array - video = np.stack(video) - metadata = None - logger.warning( - "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " - "If you model applies special processing based on metadata, please load the whole video and let the model sample frames." - ) - else: - video, metadata = load_video( - fname, - num_frames=num_frames, - fps=video_fps, - backend=video_load_backend, - sample_indices_fn=sample_indices_fn, - ) - videos.append(video) - video_metadata.append(metadata) + + # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list + if not mm_load_kwargs["load_audio_from_video"]: + for fname in audio_fnames: + batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) + else: + for fname in video_fnames: + if isinstance(fname, (list, tuple)) and isinstance(fname[0], str): + video = [np.array(load_image(image_fname)).T for image_fname in fname] + # create a 4D video because `load_video` always returns a 4D array + video = np.stack(video) + metadata = None + audios = None + logger.warning( + "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " + "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead." + ) + else: + video, metadata = load_video( + fname, + num_frames=mm_load_kwargs["num_frames"], + fps=mm_load_kwargs["video_fps"], + backend=mm_load_kwargs["video_load_backend"], + sample_indices_fn=mm_load_kwargs["sample_indices_fn"], + ) + audios = load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]) + batch_audios.append(audios) + videos.append(video) + video_metadata.append(metadata) # Currently all processors can accept nested list of batches, but not flat list of visuals # So we'll make a batched list of images and let the processor handle it @@ -1409,7 +1419,7 @@ class ProcessorMixin(PushToHubMixin): batch_images=batch_images, batch_videos=batch_videos, batch_video_metadata=batch_video_metadata, - **chat_template_kwargs, + **mm_load_kwargs, ) prompt = self.tokenizer.apply_chat_template( @@ -1438,7 +1448,7 @@ class ProcessorMixin(PushToHubMixin): text=prompt, images=batch_images if batch_images else None, videos=batch_videos if batch_videos else None, - audios=batch_audios if batch_audios else None, + audio=batch_audios if batch_audios else None, **kwargs, ) if return_dict: diff --git a/tests/test_processing_common.py b/tests/test_processing_common.py index 0485fdf97d..f4e1e1b543 100644 --- a/tests/test_processing_common.py +++ b/tests/test_processing_common.py @@ -1097,10 +1097,7 @@ class ProcessorTesterMixin: { "role": "user", "content": [ - { - "type": "video", - "path": video_file_path, - }, + {"type": "video", "path": video_file_path}, {"type": "text", "text": "What is shown in this video?"}, ], }, @@ -1189,6 +1186,70 @@ class ProcessorTesterMixin: self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1) self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243) + @require_librosa + @require_av + def test_audio_chat_template_from_video(self): + processor = self.get_processor() + if processor.chat_template is None: + self.skipTest("Processor has no chat template") + + signature = inspect.signature(processor.__call__) + if "videos" not in {*signature.parameters.keys()} or ( + signature.parameters.get("videos") is not None + and signature.parameters["videos"].annotation == inspect._empty + ): + self.skipTest(f"{self.processor_class} does not suport video inputs") + + if "feature_extractor" not in self.processor_class.attributes: + self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") + + video_file_path = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset" + ) + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "path": video_file_path}, + {"type": "text", "text": "Which of these animals is making the sound?"}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is a cow."}], + }, + { + "role": "user", + "content": [ + { + "type": "audio", + "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3", + }, + {"type": "text", "text": "Is it the same sound?"}, + ], + }, + ] + + formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False) + self.assertEqual(len(formatted_prompt), 1) # batch size=1 + + out_dict = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="np", + load_audio_from_video=True, + ) + self.assertTrue(self.audio_input_name in out_dict) + self.assertTrue(self.video_input_name in out_dict) + + # should always have input_ids and attention_mask + self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1 + self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1 + self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation + self.assertEqual(len(out_dict[self.video_input_name]), 1) # 1 video in the conversation + @require_librosa def test_audio_chat_template_single(self): processor = self.get_processor()