Support return_tensors in audio chat templates (#34601)

* add audio chat templates

* update

* update

* nit

* green ci

* we dont care about the order anymore

* clean up after rebase

* overriden tests rename

* rename shieldgemma also

* one more rename

* require_read_token

* removde images/videos

* retrigger CI flaky
This commit is contained in:
Raushan Turganbay
2025-03-25 11:08:47 +01:00
committed by GitHub
parent 19085c28da
commit 0f733110a6
11 changed files with 403 additions and 160 deletions

View File

@@ -238,7 +238,7 @@ And who is that?<|im_end|>
self.assertEqual(rendered, expected_rendered)
# Override as AriaImageProcessor doesn't accept `do_rescale`
def test_chat_template_accepts_processing_kwargs(self):
def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")

View File

@@ -116,7 +116,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_chat_template_single(self):
def test_image_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
@@ -154,7 +154,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
def test_chat_template_batched(self):
def test_image_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")

View File

@@ -11,20 +11,63 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from typing import Optional
from transformers import AutoProcessor, AutoTokenizer, Qwen2AudioProcessor, WhisperFeatureExtractor
from transformers.testing_utils import require_torch, require_torchaudio
from transformers.utils import is_torch_available
from ...test_processing_common import ProcessorTesterMixin
if is_torch_available:
pass
@require_torch
@require_torchaudio
class Qwen2AudioProcessorTest(unittest.TestCase):
class Qwen2AudioProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Qwen2AudioProcessor
def setUp(self):
self.checkpoint = "Qwen/Qwen2-Audio-7B-Instruct"
self.tmpdirname = tempfile.mkdtemp()
processor_kwargs = self.prepare_processor_dict()
processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint, **processor_kwargs)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_audio_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).audio_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_processor_dict(self):
return {
"chat_template": "{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}{% set audio_count.value = audio_count.value + 1 %}Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
}
# Override as Qwen2AudioProcessor needs audio tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <|AUDIO|>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <|AUDIO|>"]
return ["lower newer <|AUDIO|>", "<|AUDIO|> upper older longer string"] + ["<|AUDIO|> lower newer"] * (
batch_size - 2
)
def test_can_load_various_tokenizers(self):
processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint)
tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
@@ -77,7 +120,7 @@ class Qwen2AudioProcessorTest(unittest.TestCase):
"assistant",
"Ċ",
]
print(slow_tokenizer.tokenize(prompt))
self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
@@ -110,5 +153,31 @@ class Qwen2AudioProcessorTest(unittest.TestCase):
},
]
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
self.assertEqual(expected_prompt, formatted_prompt)
def test_chat_template_with_continue_final_message(self):
processor = AutoProcessor.from_pretrained(self.checkpoint)
expected_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of " # fmt: skip
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of "}],
},
]
prompt = processor.apply_chat_template(messages, continue_final_message=True)
self.assertEqual(expected_prompt, prompt)

View File

@@ -113,7 +113,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_chat_template_single(self):
def test_image_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
@@ -151,7 +151,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
def test_chat_template_batched(self):
def test_image_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")

View File

@@ -166,22 +166,22 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_accepts_processing_kwargs(self):
def test_image_chat_template_accepts_processing_kwargs(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_batched(self):
def test_image_chat_template_batched(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_dict_torch(self):
def test_image_chat_template_dict_torch(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
def test_chat_template_single(self):
def test_image_chat_template_single(self):
pass
# TODO(ryanmullins): Adapt this test for ShieldGemma 2

View File

@@ -18,8 +18,6 @@ import shutil
import tempfile
import unittest
import numpy as np
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
from transformers.utils import FEATURE_EXTRACTOR_NAME
@@ -30,6 +28,8 @@ from .test_feature_extraction_wav2vec2 import floats_list
class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Wav2Vec2Processor
audio_input_name = "input_values"
text_input_name = "labels"
def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
@@ -132,22 +132,6 @@ class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_padding_argument_not_ignored(self):
# padding, or any other overlap arg between audio extractor and tokenizer
# should be passed to both text and audio and not ignored
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
# padding = True should not raise an error and will if the audio processor popped its value to None
_ = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()

View File

@@ -18,8 +18,6 @@ import shutil
import tempfile
import unittest
import numpy as np
from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
@@ -32,6 +30,7 @@ from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Wav2Vec2BertProcessor
text_input_name = "labels"
def setUp(self):
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
@@ -136,22 +135,6 @@ class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_padding_argument_not_ignored(self):
# padding, or any other overlap arg between audio extractor and tokenizer
# should be passed to both text and audio and not ignored
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
batch_duration_in_seconds = [1, 3, 2, 6]
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
# padding = True should not raise an error and will if the audio processor popped its value to None
# processor(input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt")
_ = processor(
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
)
def test_tokenizer_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()

View File

@@ -29,6 +29,7 @@ from transformers.processing_utils import Unpack
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_av,
require_librosa,
require_torch,
require_vision,
)
@@ -73,6 +74,7 @@ class ProcessorTesterMixin:
text_input_name = "input_ids"
images_input_name = "pixel_values"
videos_input_name = "pixel_values_videos"
audio_input_name = "input_features"
def prepare_processor_dict(self):
return {}
@@ -105,6 +107,8 @@ class ProcessorTesterMixin:
processor = self.processor_class(**components, **self.prepare_processor_dict())
return processor
# TODO: raushan unify all these special token LLMs under the general preparation. We can get audio/image token
# from tokenizer, so we can generalize instead of overriding
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer"
@@ -363,101 +367,83 @@ class ProcessorTesterMixin:
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
# text + audio kwargs testing
# text + audio kwargs testing
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117, padding="max_length")
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
else:
self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
tokenizer = self.get_component("tokenizer", max_length=300, padding="max_length")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000))
raw_speech = [np.asarray(audio) for audio in raw_speech]
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 117)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 117)
self.assertEqual(len(inputs[self.text_input_name][0]), 300)
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
tokenizer = self.get_component("tokenizer", max_length=117)
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000))
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 112)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 112)
raw_speech = [np.asarray(audio) for audio in raw_speech]
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=300, padding="max_length")
self.assertEqual(len(inputs[self.text_input_name][0]), 300)
@require_torch
def test_unstructured_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer(max_length=117)
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer", max_length=117)
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
tokenizer = self.get_component("tokenizer")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000))
inputs = processor(
text=input_str,
audio=raw_speech,
return_tensors="pt",
padding="max_length",
max_length=76,
)
raw_speech = [np.asarray(audio) for audio in raw_speech]
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=300, padding="max_length")
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
self.assertEqual(len(inputs[self.text_input_name][0]), 300)
@require_torch
def test_doubly_passed_kwargs_audio(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
tokenizer = self.get_component("tokenizer")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000))
raw_speech = [np.asarray(audio) for audio in raw_speech]
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
audio=raw_speech,
audio_kwargs={"padding": "max_length"},
text_kwargs={"padding": "max_length"},
padding="max_length",
)
@@ -466,31 +452,27 @@ class ProcessorTesterMixin:
def test_structured_kwargs_audio_nested(self):
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
if hasattr(self, "get_tokenizer"):
tokenizer = self.get_tokenizer()
elif hasattr(self, "get_component"):
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
tokenizer = self.get_component("tokenizer", max_length=117)
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = ["lower newer"]
input_str = self.prepare_text_inputs(batch_size=3)
raw_speech = floats_list((3, 1000))
raw_speech = [np.asarray(audio) for audio in raw_speech]
# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"text_kwargs": {"padding": "max_length", "max_length": 76},
"audio_kwargs": {"padding": "max_length", "max_length": 66},
"audio_kwargs": {"padding": "max_length", "max_length": 300},
}
inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
if "input_ids" in inputs:
self.assertEqual(len(inputs["input_ids"][0]), 76)
elif "labels" in inputs:
self.assertEqual(len(inputs["labels"][0]), 76)
self.assertEqual(len(inputs[self.text_input_name][0]), 76)
def test_tokenizer_defaults_preserved_by_kwargs_video(self):
if "video_processor" not in self.processor_class.attributes:
@@ -680,9 +662,10 @@ class ProcessorTesterMixin:
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs
# TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
def test_overlapping_text_kwargs_handling(self):
def test_overlapping_text_image_kwargs_handling(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
processor_components = self.prepare_components()
processor = self.processor_class(**processor_components)
self.skip_processor_without_typed_kwargs(processor)
@@ -699,6 +682,28 @@ class ProcessorTesterMixin:
text_kwargs={"padding": "do_not_pad"},
)
def test_overlapping_text_audio_kwargs_handling(self):
"""
Checks that `padding`, or any other overlap arg between audio extractor and tokenizer
is be passed to only text and ignored for audio for BC purposes
"""
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
feature_extractor = self.get_component("feature_extractor")
tokenizer = self.get_component("tokenizer")
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
self.skip_processor_without_typed_kwargs(processor)
input_str = self.prepare_text_inputs(batch_size=3)
audio_lengths = [4000, 8000, 16000, 32000]
raw_speech = [np.asarray(audio)[:length] for audio, length in zip(floats_list((3, 32_000)), audio_lengths)]
# padding = True should not raise an error and will if the audio processor popped its value to None
_ = processor(text=input_str, audio=raw_speech, padding=True, return_tensors="pt")
def test_prepare_and_validate_optional_call_args(self):
processor = self.get_processor()
optional_call_args_name = getattr(processor, "optional_call_args", [])
@@ -752,11 +757,14 @@ class ProcessorTesterMixin:
# the reloaded tokenizer should get the chat template as well
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
def test_chat_template_single(self):
def test_image_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [
[
{
@@ -797,11 +805,14 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict["attention_mask"]), 1)
self.assertEqual(len(out_dict[self.images_input_name]), 1)
def test_chat_template_batched(self):
def test_image_chat_template_batched(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
batched_messages = [
[
{
@@ -864,11 +875,14 @@ class ProcessorTesterMixin:
self.assertEqual(len(out_dict["attention_mask"]), 2)
self.assertEqual(len(out_dict[self.images_input_name]), 2)
def test_chat_template_accepts_processing_kwargs(self):
def test_image_chat_template_accepts_processing_kwargs(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [
[
{
@@ -915,11 +929,14 @@ class ProcessorTesterMixin:
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
@require_torch
def test_chat_template_dict_torch(self):
def test_image_chat_template_dict_torch(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
messages = [
{
"role": "user",
@@ -1171,3 +1188,117 @@ class ProcessorTesterMixin:
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
@require_librosa
def test_audio_chat_template_single(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
},
{"type": "text", "text": "How about this one?"},
],
},
]
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
self.assertEqual(len(formatted_prompt), 1) # batch size=1
formatted_prompt_tokenized = processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
)
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
self.assertListEqual(expected_output, formatted_prompt_tokenized)
messages[1]["content"][0]["audio"] = (
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
)
messages[3]["content"][0]["audio"] = (
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
)
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
self.assertTrue(self.audio_input_name in out_dict)
# should always have input_ids and attention_mask
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
@require_torch
@require_librosa
def test_audio_chat_template_dict_torch(self):
processor = self.get_processor()
if processor.chat_template is None:
self.skipTest("Processor has no chat template")
if "feature_extractor" not in self.processor_class.attributes:
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
messages = [
{
"role": "system",
"content": [{"type": "text", "text": "You are a helpful assistant."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
},
{"type": "text", "text": "What's that sound?"},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
},
{
"role": "user",
"content": [
{
"type": "audio",
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav",
},
{"type": "text", "text": "How about this one?"},
],
},
]
out_dict_tensors = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
self.assertTrue(self.audio_input_name in out_dict_tensors)
for k in out_dict_tensors:
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)