Support return_tensors in audio chat templates (#34601)

* add audio chat templates

* update

* update

* nit

* green ci

* we dont care about the order anymore

* clean up after rebase

* overriden tests rename

* rename shieldgemma also

* one more rename

* require_read_token

* removde images/videos

* retrigger CI flaky
This commit is contained in:
Raushan Turganbay
2025-03-25 11:08:47 +01:00
committed by GitHub
parent 19085c28da
commit 0f733110a6
11 changed files with 403 additions and 160 deletions

View File

@@ -16,10 +16,52 @@ Audio processing functions to extract features from audio waveforms. This code i
and remove unnecessary dependencies. and remove unnecessary dependencies.
""" """
import os
import warnings import warnings
from io import BytesIO
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import requests
from .utils import is_librosa_available, requires_backends
if is_librosa_available():
import librosa
def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
"""
Loads `audio` to an np.ndarray object.
Args:
audio (`str` or `np.ndarray`):
The audio to be laoded to the numpy array format.
sampling_rate (`int`, *optional*, defaults to 16000):
The samlping rate to be used when loading the audio. It should be same as the
sampling rate the model you will be using further was trained with.
timeout (`float`, *optional*):
The timeout value in seconds for the URL request.
Returns:
`np.ndarray`: A numpy artay representing the audio.
"""
requires_backends(load_audio, ["librosa"])
if isinstance(audio, str):
# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
if audio.startswith("http://") or audio.startswith("https://"):
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
elif os.path.isfile(audio):
audio = librosa.load(audio, sr=sampling_rate)[0]
elif isinstance(audio, np.ndarray):
audio = audio
else:
raise TypeError(
"Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
)
return audio
AudioInput = Union[ AudioInput = Union[

View File

@@ -16,13 +16,24 @@
Processor class for Qwen2Audio. Processor class for Qwen2Audio.
""" """
from typing import List, Optional, Union import warnings
from typing import List, Union
import numpy as np import numpy as np
from ...feature_extraction_utils import BatchFeature from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils.deprecation import deprecate_kwarg
class Qwen2AudioProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {
"text_kwargs": {
"padding": False,
},
"audio_kwargs": {},
}
class Qwen2AudioProcessor(ProcessorMixin): class Qwen2AudioProcessor(ProcessorMixin):
@@ -49,6 +60,7 @@ class Qwen2AudioProcessor(ProcessorMixin):
""" """
attributes = ["feature_extractor", "tokenizer"] attributes = ["feature_extractor", "tokenizer"]
valid_kwargs = ["chat_template", "audio_token", "audio_bos_token", "audio_eos_token"]
feature_extractor_class = "WhisperFeatureExtractor" feature_extractor_class = "WhisperFeatureExtractor"
tokenizer_class = "AutoTokenizer" tokenizer_class = "AutoTokenizer"
@@ -68,13 +80,13 @@ class Qwen2AudioProcessor(ProcessorMixin):
self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token self.audio_eos_token = tokenizer.audio_eos_token if hasattr(tokenizer, "audio_eos_token") else audio_eos_token
super().__init__(feature_extractor, tokenizer, chat_template=chat_template) super().__init__(feature_extractor, tokenizer, chat_template=chat_template)
@deprecate_kwarg("audios", version="4.54.0", new_name="audio")
def __call__( def __call__(
self, self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
audios: Union[np.ndarray, List[np.ndarray]] = None, audio: Union[np.ndarray, List[np.ndarray]] = None,
padding: Union[bool, str, PaddingStrategy] = False, audios=None, # kept for BC
sampling_rate: Optional[int] = None, **kwargs: Unpack[Qwen2AudioProcessorKwargs],
**kwargs,
) -> BatchFeature: ) -> BatchFeature:
""" """
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text` Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
@@ -88,43 +100,48 @@ class Qwen2AudioProcessor(ProcessorMixin):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audios (`np.ndarray`, `List[np.ndarray]`): audio (`np.ndarray`, `List[np.ndarray]`):
The audio or batch of audios to be prepared. Each audio can be a NumPy array. The audio or batch of audios to be prepared. Each audio can be a NumPy array.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
sampling_rate (`int`, defaults to 16000):
The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
""" """
# Handle BC when user passes deprecared keyword argument
if audios is not None and audio is None:
audio = audios
warnings.wanr(
"You may have used the keyword argument for the `audio` inputs. It is strongly recommended to pass inputs with keyword arguments "
"with keys `audio` and `text`. From transformers v4.55 `audio` will be the onle acceptable keyword argument.",
FutureWarning,
)
if text is None: if text is None:
raise ValueError("You need to specify either a `text` input to process.") raise ValueError("You need to specify `text` input to process.")
elif isinstance(text, str): elif isinstance(text, str):
text = [text] text = [text]
elif not isinstance(text, list) and not isinstance(text[0], str): elif not isinstance(text, list) and not isinstance(text[0], str):
raise ValueError("Invalid input text. Please provide a string, or a list of strings") raise ValueError("Invalid input text. Please provide a string, or a list of strings")
output_kwargs = self._merge_kwargs(
Qwen2AudioProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if audio is not None:
# ensure we have as much audios as audio tokens # ensure we have as much audios as audio tokens
num_audio_tokens = sum(sample.count(self.audio_token) for sample in text) num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
num_audios = 1 if isinstance(audios, np.ndarray) else len(audios) num_audios = 1 if type(audio) == np.ndarray else len(audio)
if num_audio_tokens != num_audios: if num_audio_tokens != num_audios:
raise ValueError( raise ValueError(
f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}" f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"
) )
if audios is not None: # Some kwargs should not be changed so we can expand text with audio tokens below
audio_inputs = self.feature_extractor( output_kwargs["audio_kwargs"]["return_attention_mask"] = True
audios, sampling_rate=sampling_rate, return_attention_mask=True, padding="max_length", **kwargs output_kwargs["audio_kwargs"]["padding"] = "max_length"
) audio_inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
audio_inputs["feature_attention_mask"] = audio_inputs.pop(
"attention_mask" # rename attention_mask to prevent conflicts later on
) # rename attention_mask to prevent conflicts later on audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")
expanded_text = [] expanded_text = []
audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist() audio_lengths = audio_inputs["feature_attention_mask"].sum(-1).tolist()
@@ -162,9 +179,9 @@ class Qwen2AudioProcessor(ProcessorMixin):
expanded_text.append(sample) expanded_text.append(sample)
text = expanded_text text = expanded_text
inputs = self.tokenizer(text, padding=padding, **kwargs) inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
if audios is not None: if audio is not None:
inputs.update(audio_inputs) inputs.update(audio_inputs)
return BatchFeature(data={**inputs}) return BatchFeature(data={**inputs})
@@ -190,6 +207,7 @@ class Qwen2AudioProcessor(ProcessorMixin):
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names + ["feature_attention_mask"])) return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names + ["feature_attention_mask"]))
@property @property
# NOTE: we don't have default templates anymore, and the below is kept only because the hub config is not yet updated!
def default_chat_template(self): def default_chat_template(self):
""" """
This default vicuna template formats inputs in the form of a chat history. For each message in the chat history: This default vicuna template formats inputs in the form of a chat history. For each message in the chat history:
@@ -228,7 +246,7 @@ class Qwen2AudioProcessor(ProcessorMixin):
"{{ message['content'] }}<|im_end|>\n" "{{ message['content'] }}<|im_end|>\n"
"{% else %}" "{% else %}"
"{% for content in message['content'] %}" "{% for content in message['content'] %}"
"{% if 'audio' in content or 'audio_url' in content %}" "{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}"
"{% set audio_count.value = audio_count.value + 1 %}" "{% set audio_count.value = audio_count.value + 1 %}"
"Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n" "Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
"{% elif 'text' in content %}" "{% elif 'text' in content %}"

View File

@@ -28,6 +28,7 @@ from typing import Any, Callable, Optional, TypedDict, Union
import numpy as np import numpy as np
import typing_extensions import typing_extensions
from .audio_utils import load_audio
from .dynamic_module_utils import custom_object_save from .dynamic_module_utils import custom_object_save
from .image_utils import ( from .image_utils import (
ChannelDimension, ChannelDimension,
@@ -419,6 +420,7 @@ class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False):
num_frames: Optional[int] = None num_frames: Optional[int] = None
video_load_backend: Optional[str] = "pyav" video_load_backend: Optional[str] = "pyav"
video_fps: Optional[int] = None video_fps: Optional[int] = None
sampling_rate: Optional[int] = 16_000
sample_indices_fn: Optional[Callable] = None sample_indices_fn: Optional[Callable] = None
@@ -938,6 +940,7 @@ class ProcessorMixin(PushToHubMixin):
"common_kwargs": {}, "common_kwargs": {},
} }
possible_modality_keywords = {"text", "audio", "videos", "images"}
used_keys = set() used_keys = set()
# get defaults from set model processor kwargs if they exist # get defaults from set model processor kwargs if they exist
@@ -995,7 +998,7 @@ class ProcessorMixin(PushToHubMixin):
if key not in used_keys: if key not in used_keys:
if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys(): if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__.keys():
output_kwargs["common_kwargs"][key] = kwargs[key] output_kwargs["common_kwargs"][key] = kwargs[key]
else: elif key not in possible_modality_keywords:
logger.warning_once( logger.warning_once(
f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored." f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
) )
@@ -1336,15 +1339,23 @@ class ProcessorMixin(PushToHubMixin):
tokenize = chat_template_kwargs.get("tokenize") tokenize = chat_template_kwargs.get("tokenize")
return_dict = chat_template_kwargs.get("return_dict") return_dict = chat_template_kwargs.get("return_dict")
sample_indices_fn = chat_template_kwargs.get("sample_indices_fn") sample_indices_fn = chat_template_kwargs.get("sample_indices_fn")
sampling_rate = chat_template_kwargs.pop("sampling_rate")
if tokenize: if tokenize:
batch_images, batch_videos = [], [] batch_images, batch_videos = [], []
batch_audios = []
batch_video_metadata = [] batch_video_metadata = []
for conversation in conversations: for conversation in conversations:
images, videos = [], [] images, videos = [], []
video_metadata = [] video_metadata = []
for message in conversation: for message in conversation:
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
audio_fnames = [
content[key]
for content in message["content"]
for key in ["audio", "url", "path"]
if key in content and content["type"] == "audio"
]
image_fnames = [ image_fnames = [
vision_info[key] vision_info[key]
for vision_info in visuals for vision_info in visuals
@@ -1357,6 +1368,10 @@ class ProcessorMixin(PushToHubMixin):
for key in ["video", "url", "path"] for key in ["video", "url", "path"]
if key in vision_info and vision_info["type"] == "video" if key in vision_info and vision_info["type"] == "video"
] ]
# Audio models do not accept nested list of audios (yet!)
for fname in audio_fnames:
batch_audios.append(load_audio(fname, sampling_rate=sampling_rate))
for fname in image_fnames: for fname in image_fnames:
images.append(load_image(fname)) images.append(load_image(fname))
for fname in video_fnames: for fname in video_fnames:
@@ -1423,6 +1438,7 @@ class ProcessorMixin(PushToHubMixin):
text=prompt, text=prompt,
images=batch_images if batch_images else None, images=batch_images if batch_images else None,
videos=batch_videos if batch_videos else None, videos=batch_videos if batch_videos else None,
audios=batch_audios if batch_audios else None,
**kwargs, **kwargs,
) )
if return_dict: if return_dict:

View File

@@ -238,7 +238,7 @@ And who is that?<|im_end|>
self.assertEqual(rendered, expected_rendered) self.assertEqual(rendered, expected_rendered)
# Override as AriaImageProcessor doesn't accept `do_rescale` # Override as AriaImageProcessor doesn't accept `do_rescale`
def test_chat_template_accepts_processing_kwargs(self): def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")

View File

@@ -116,7 +116,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(list(inputs.keys()), processor.model_input_names) self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_chat_template_single(self): def test_image_chat_template_single(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")
@@ -154,7 +154,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(len(out_dict["attention_mask"]), 1) self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 71280) self.assertEqual(len(out_dict[self.images_input_name]), 71280)
def test_chat_template_batched(self): def test_image_chat_template_batched(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")

View File

@@ -11,20 +11,63 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import shutil
import tempfile import tempfile
import unittest import unittest
from typing import Optional
from transformers import AutoProcessor, AutoTokenizer, Qwen2AudioProcessor, WhisperFeatureExtractor from transformers import AutoProcessor, AutoTokenizer, Qwen2AudioProcessor, WhisperFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio from transformers.testing_utils import require_torch, require_torchaudio
from transformers.utils import is_torch_available
from ...test_processing_common import ProcessorTesterMixin
if is_torch_available:
pass
@require_torch @require_torch
@require_torchaudio @require_torchaudio
class Qwen2AudioProcessorTest(unittest.TestCase): class Qwen2AudioProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Qwen2AudioProcessor
def setUp(self): def setUp(self):
self.checkpoint = "Qwen/Qwen2-Audio-7B-Instruct" self.checkpoint = "Qwen/Qwen2-Audio-7B-Instruct"
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
processor_kwargs = self.prepare_processor_dict()
processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint, **processor_kwargs)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_audio_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).audio_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_processor_dict(self):
return {
"chat_template": "{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}{% set audio_count.value = audio_count.value + 1 %}Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
}
# Override as Qwen2AudioProcessor needs audio tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <|AUDIO|>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <|AUDIO|>"]
return ["lower newer <|AUDIO|>", "<|AUDIO|> upper older longer string"] + ["<|AUDIO|> lower newer"] * (
batch_size - 2
)
def test_can_load_various_tokenizers(self): def test_can_load_various_tokenizers(self):
processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint) processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint)
tokenizer = AutoTokenizer.from_pretrained(self.checkpoint) tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
@@ -77,7 +120,7 @@ class Qwen2AudioProcessorTest(unittest.TestCase):
"assistant", "assistant",
"Ċ", "Ċ",
] ]
print(slow_tokenizer.tokenize(prompt))
self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT) self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
@@ -110,5 +153,31 @@ class Qwen2AudioProcessorTest(unittest.TestCase):
}, },
] ]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt) self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_with_continue_final_message(self):
processor = AutoProcessor.from_pretrained(self.checkpoint)
expected_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of " # fmt: skip
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of "}],
},
]
prompt = processor.apply_chat_template(messages, continue_final_message=True)
self.assertEqual(expected_prompt, prompt)

View File

@@ -113,7 +113,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(list(inputs.keys()), processor.model_input_names) self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_chat_template_single(self): def test_image_chat_template_single(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")
@@ -151,7 +151,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(len(out_dict["attention_mask"]), 1) self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 71280) self.assertEqual(len(out_dict[self.images_input_name]), 71280)
def test_chat_template_batched(self): def test_image_chat_template_batched(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")

View File

@@ -166,22 +166,22 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# TODO(ryanmullins): Adapt this test for ShieldGemma 2 # TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_accepts_processing_kwargs(self): def test_image_chat_template_accepts_processing_kwargs(self):
pass pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2 # TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_batched(self): def test_image_chat_template_batched(self):
pass pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2 # TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_dict_torch(self): def test_image_chat_template_dict_torch(self):
pass pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2 # TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.") @unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_single(self): def test_image_chat_template_single(self):
pass pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2 # TODO(ryanmullins): Adapt this test for ShieldGemma 2

View File

@@ -18,8 +18,6 @@ import shutil
import tempfile import tempfile
import unittest import unittest
import numpy as np
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.utils import FEATURE_EXTRACTOR_NAME from transformers.utils import FEATURE_EXTRACTOR_NAME
@@ -30,6 +28,8 @@ from .test_feature_extraction_wav2vec2 import floats_list
class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase): class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Wav2Vec2Processor processor_class = Wav2Vec2Processor
audio_input_name = "input_values"
text_input_name = "labels"
def setUp(self): def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ") vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
@@ -132,22 +132,6 @@ class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
for key in encoded_tok.keys(): for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key]) self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_padding_argument_not_ignored(self):
# padding, or any other overlap arg between audio extractor and tokenizer
# should be passed to both text and audio and not ignored
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
# padding = True should not raise an error and will if the audio processor popped its value to None
_ = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
def test_tokenizer_decode(self): def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()

View File

@@ -18,8 +18,6 @@ import shutil
import tempfile import tempfile
import unittest import unittest
import numpy as np
from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@@ -32,6 +30,7 @@ from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase): class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Wav2Vec2BertProcessor processor_class = Wav2Vec2BertProcessor
text_input_name = "labels"
def setUp(self): def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ") vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
@@ -136,22 +135,6 @@ class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
for key in encoded_tok.keys(): for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key]) self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_padding_argument_not_ignored(self):
# padding, or any other overlap arg between audio extractor and tokenizer
# should be passed to both text and audio and not ignored
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
# padding = True should not raise an error and will if the audio processor popped its value to None
# processor(input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt")
_ = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
def test_tokenizer_decode(self): def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor() feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()

View File

@@ -29,6 +29,7 @@ from transformers.processing_utils import Unpack
from transformers.testing_utils import ( from transformers.testing_utils import (
check_json_file_has_correct_format, check_json_file_has_correct_format,
require_av, require_av,
require_librosa,
require_torch, require_torch,
require_vision, require_vision,
) )
@@ -73,6 +74,7 @@ class ProcessorTesterMixin:
text_input_name = "input_ids" text_input_name = "input_ids"
images_input_name = "pixel_values" images_input_name = "pixel_values"
videos_input_name = "pixel_values_videos" videos_input_name = "pixel_values_videos"
audio_input_name = "input_features"
def prepare_processor_dict(self): def prepare_processor_dict(self):
return {} return {}
@@ -105,6 +107,8 @@ class ProcessorTesterMixin:
processor = self.processor_class(**components, **self.prepare_processor_dict()) processor = self.processor_class(**components, **self.prepare_processor_dict())
return processor return processor
# TODO: raushan unify all these special token LLMs under the general preparation. We can get audio/image token
# from tokenizer, so we can generalize instead of overriding
def prepare_text_inputs(self, batch_size: Optional[int] = None): def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None: if batch_size is None:
return "lower newer" return "lower newer"
@@ -368,96 +372,78 @@ class ProcessorTesterMixin:
def test_tokenizer_defaults_preserved_by_kwargs_audio(self): def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes: if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor") feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"): tokenizer = self.get_component("tokenizer", max_length=300, padding="max_length")
tokenizer = self.get_tokenizer(max_length=117, padding="max_length") processor_kwargs = self.prepare_processor_dict()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length") processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
else:
self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000)) raw_speech = floats_list((3, 1000))
raw_speech = [np.asarray(audio) for audio in raw_speech]
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt") inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
if "input_ids" in inputs: self.assertEqual(len(inputs[self.text_input_name][0]), 300)
self.assertEqual(len(inputs["input_ids"][0]), 117)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 117)
@require_torch @require_torch
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self): def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes: if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor") feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117) tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token: processor_kwargs = self.prepare_processor_dict()
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000)) raw_speech = floats_list((3, 1000))
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length") raw_speech = [np.asarray(audio) for audio in raw_speech]
if "input_ids" in inputs: inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=300, padding="max_length")
self.assertEqual(len(inputs["input_ids"][0]), 112)
elif "labels" in inputs: self.assertEqual(len(inputs[self.text_input_name][0]), 300)
self.assertEqual(len(inputs["labels"][0]), 112)
@require_torch @require_torch
def test_unstructured_kwargs_audio(self): def test_unstructured_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes: if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor") feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"): tokenizer = self.get_component("tokenizer")
tokenizer = self.get_tokenizer(max_length=117) processor_kwargs = self.prepare_processor_dict()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117) processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer" input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000)) raw_speech = floats_list((3, 1000))
inputs = processor( raw_speech = [np.asarray(audio) for audio in raw_speech]
text=input_str, inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=300, padding="max_length")
audio=raw_speech,
return_tensors="pt",
padding="max_length",
max_length=76,
)
if "input_ids" in inputs: self.assertEqual(len(inputs[self.text_input_name][0]), 300)
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
@require_torch @require_torch
def test_doubly_passed_kwargs_audio(self): def test_doubly_passed_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes: if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor") feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer") tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token: processor_kwargs = self.prepare_processor_dict()
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor) processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"] input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000)) raw_speech = floats_list((3, 1000))
raw_speech = [np.asarray(audio) for audio in raw_speech]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = processor( _ = processor(
text=input_str, text=input_str,
audio=raw_speech, audio=raw_speech,
audio_kwargs={"padding": "max_length"}, text_kwargs={"padding": "max_length"},
padding="max_length", padding="max_length",
) )
@@ -466,31 +452,27 @@ class ProcessorTesterMixin:
def test_structured_kwargs_audio_nested(self): def test_structured_kwargs_audio_nested(self):
if "feature_extractor" not in self.processor_class.attributes: if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}") self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor") feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"): tokenizer = self.get_component("tokenizer", max_length=117)
tokenizer = self.get_tokenizer() processor_kwargs = self.prepare_processor_dict()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer") processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"] input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000)) raw_speech = floats_list((3, 1000))
raw_speech = [np.asarray(audio) for audio in raw_speech]
# Define the kwargs for each modality # Define the kwargs for each modality
all_kwargs = { all_kwargs = {
"common_kwargs": {"return_tensors": "pt"}, "common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76}, "text_kwargs": {"padding": "max_length", "max_length": 76},
"audio_kwargs": {"padding": "max_length", "max_length": 66}, "audio_kwargs": {"padding": "max_length", "max_length": 300},
} }
inputs = processor(text=input_str, audio=raw_speech, **all_kwargs) inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
if "input_ids" in inputs: self.assertEqual(len(inputs[self.text_input_name][0]), 76)
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
def test_tokenizer_defaults_preserved_by_kwargs_video(self): def test_tokenizer_defaults_preserved_by_kwargs_video(self):
if "video_processor" not in self.processor_class.attributes: if "video_processor" not in self.processor_class.attributes:
@@ -680,9 +662,10 @@ class ProcessorTesterMixin:
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs # TODO: the same test, but for audio + text processors that have strong overlap in kwargs
# TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication # TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
def test_overlapping_text_kwargs_handling(self): def test_overlapping_text_image_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes: if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}") self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components() processor_components = self.prepare_components()
processor = self.processor_class(**processor_components) processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor) self.skip_processor_without_typed_kwargs(processor)
@@ -699,6 +682,28 @@ class ProcessorTesterMixin:
text_kwargs={"padding": "do_not_pad"}, text_kwargs={"padding": "do_not_pad"},
) )
def test_overlapping_text_audio_kwargs_handling(self):
"""
Checks that `padding`, or any other overlap arg between audio extractor and tokenizer
is be passed to only text and ignored for audio for BC purposes
"""
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
tokenizer = self.get_component("tokenizer")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs(batch_size=3)
audio_lengths = [4000, 8000, 16000, 32000]
raw_speech = [np.asarray(audio)[:length] for audio, length in zip(floats_list((3, 32_000)), audio_lengths)]
# padding = True should not raise an error and will if the audio processor popped its value to None
_ = processor(text=input_str, audio=raw_speech, padding=True, return_tensors="pt")
def test_prepare_and_validate_optional_call_args(self): def test_prepare_and_validate_optional_call_args(self):
processor = self.get_processor() processor = self.get_processor()
optional_call_args_name = getattr(processor, "optional_call_args", []) optional_call_args_name = getattr(processor, "optional_call_args", [])
@@ -752,11 +757,14 @@ class ProcessorTesterMixin:
# the reloaded tokenizer should get the chat template as well # the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template) self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
def test_chat_template_single(self): def test_image_chat_template_single(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [ messages = [
[ [
{ {
@@ -797,11 +805,14 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict["attention_mask"]), 1) self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 1) self.assertEqual(len(out_dict[self.images_input_name]), 1)
def test_chat_template_batched(self): def test_image_chat_template_batched(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
batched_messages = [ batched_messages = [
[ [
{ {
@@ -864,11 +875,14 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict["attention_mask"]), 2) self.assertEqual(len(out_dict["attention_mask"]), 2)
self.assertEqual(len(out_dict[self.images_input_name]), 2) self.assertEqual(len(out_dict[self.images_input_name]), 2)
def test_chat_template_accepts_processing_kwargs(self): def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [ messages = [
[ [
{ {
@@ -915,11 +929,14 @@ class ProcessorTesterMixin:
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0) self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
@require_torch @require_torch
def test_chat_template_dict_torch(self): def test_image_chat_template_dict_torch(self):
processor = self.get_processor() processor = self.get_processor()
if processor.chat_template is None: if processor.chat_template is None:
self.skipTest("Processor has no chat template") self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [ messages = [
{ {
"role": "user", "role": "user",
@@ -1171,3 +1188,117 @@ class ProcessorTesterMixin:
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text) self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1) self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243) self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
@require_librosa
def test_audio_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
},
{"type": "text", "text": "How about this one?"},
],
},
]
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1) # batch size=1
formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
messages[1]["content"][0]["audio"] = (
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
)
messages[3]["content"][0]["audio"] = (
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.audio_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
@require_torch
@require_librosa
def test_audio_chat_template_dict_torch(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav",
},
{"type": "text", "text": "How about this one?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.audio_input_name in out_dict_tensors)
for k in out_dict_tensors:
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)