Support return_tensors in audio chat templates (#34601)
* add audio chat templates * update * update * nit * green ci * we dont care about the order anymore * clean up after rebase * overriden tests rename * rename shieldgemma also * one more rename * require_read_token * removde images/videos * retrigger CI flaky
This commit is contained in:
committed by
GitHub
parent
19085c28da
commit
0f733110a6
@@ -29,6 +29,7 @@ from transformers.processing_utils import Unpack
|
||||
from transformers.testing_utils import (
|
||||
check_json_file_has_correct_format,
|
||||
require_av,
|
||||
require_librosa,
|
||||
require_torch,
|
||||
require_vision,
|
||||
)
|
||||
@@ -73,6 +74,7 @@ class ProcessorTesterMixin:
|
||||
text_input_name = "input_ids"
|
||||
images_input_name = "pixel_values"
|
||||
videos_input_name = "pixel_values_videos"
|
||||
audio_input_name = "input_features"
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {}
|
||||
@@ -105,6 +107,8 @@ class ProcessorTesterMixin:
|
||||
processor = self.processor_class(**components, **self.prepare_processor_dict())
|
||||
return processor
|
||||
|
||||
# TODO: raushan unify all these special token LLMs under the general preparation. We can get audio/image token
|
||||
# from tokenizer, so we can generalize instead of overriding
|
||||
def prepare_text_inputs(self, batch_size: Optional[int] = None):
|
||||
if batch_size is None:
|
||||
return "lower newer"
|
||||
@@ -363,101 +367,83 @@ class ProcessorTesterMixin:
|
||||
self.assertLessEqual(inputs[self.images_input_name][0][0].mean(), 0)
|
||||
self.assertEqual(inputs[self.text_input_name].shape[-1], 76)
|
||||
|
||||
# text + audio kwargs testing
|
||||
# text + audio kwargs testing
|
||||
@require_torch
|
||||
def test_tokenizer_defaults_preserved_by_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer(max_length=117, padding="max_length")
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer", max_length=117, padding="max_length")
|
||||
else:
|
||||
self.assertTrue(False, "Processor doesn't have get_tokenizer or get_component defined")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
tokenizer = self.get_component("tokenizer", max_length=300, padding="max_length")
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
|
||||
input_str = self.prepare_text_inputs(batch_size=3)
|
||||
raw_speech = floats_list((3, 1000))
|
||||
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
||||
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt")
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 117)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 117)
|
||||
self.assertEqual(len(inputs[self.text_input_name][0]), 300)
|
||||
|
||||
@require_torch
|
||||
def test_kwargs_overrides_default_tokenizer_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer(max_length=117)
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
input_str = "lower newer"
|
||||
|
||||
input_str = self.prepare_text_inputs(batch_size=3)
|
||||
raw_speech = floats_list((3, 1000))
|
||||
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=112, padding="max_length")
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 112)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 112)
|
||||
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
||||
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=300, padding="max_length")
|
||||
|
||||
self.assertEqual(len(inputs[self.text_input_name][0]), 300)
|
||||
|
||||
@require_torch
|
||||
def test_unstructured_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer(max_length=117)
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = "lower newer"
|
||||
input_str = self.prepare_text_inputs(batch_size=3)
|
||||
raw_speech = floats_list((3, 1000))
|
||||
inputs = processor(
|
||||
text=input_str,
|
||||
audio=raw_speech,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=76,
|
||||
)
|
||||
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
||||
inputs = processor(text=input_str, audio=raw_speech, return_tensors="pt", max_length=300, padding="max_length")
|
||||
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 76)
|
||||
self.assertEqual(len(inputs[self.text_input_name][0]), 300)
|
||||
|
||||
@require_torch
|
||||
def test_doubly_passed_kwargs_audio(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer()
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer"]
|
||||
input_str = self.prepare_text_inputs(batch_size=3)
|
||||
raw_speech = floats_list((3, 1000))
|
||||
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
||||
with self.assertRaises(ValueError):
|
||||
_ = processor(
|
||||
text=input_str,
|
||||
audio=raw_speech,
|
||||
audio_kwargs={"padding": "max_length"},
|
||||
text_kwargs={"padding": "max_length"},
|
||||
padding="max_length",
|
||||
)
|
||||
|
||||
@@ -466,31 +452,27 @@ class ProcessorTesterMixin:
|
||||
def test_structured_kwargs_audio_nested(self):
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
if hasattr(self, "get_tokenizer"):
|
||||
tokenizer = self.get_tokenizer()
|
||||
elif hasattr(self, "get_component"):
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
if not tokenizer.pad_token:
|
||||
tokenizer.pad_token = "[TEST_PAD]"
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
tokenizer = self.get_component("tokenizer", max_length=117)
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = ["lower newer"]
|
||||
input_str = self.prepare_text_inputs(batch_size=3)
|
||||
raw_speech = floats_list((3, 1000))
|
||||
raw_speech = [np.asarray(audio) for audio in raw_speech]
|
||||
|
||||
# Define the kwargs for each modality
|
||||
all_kwargs = {
|
||||
"common_kwargs": {"return_tensors": "pt"},
|
||||
"text_kwargs": {"padding": "max_length", "max_length": 76},
|
||||
"audio_kwargs": {"padding": "max_length", "max_length": 66},
|
||||
"audio_kwargs": {"padding": "max_length", "max_length": 300},
|
||||
}
|
||||
|
||||
inputs = processor(text=input_str, audio=raw_speech, **all_kwargs)
|
||||
if "input_ids" in inputs:
|
||||
self.assertEqual(len(inputs["input_ids"][0]), 76)
|
||||
elif "labels" in inputs:
|
||||
self.assertEqual(len(inputs["labels"][0]), 76)
|
||||
self.assertEqual(len(inputs[self.text_input_name][0]), 76)
|
||||
|
||||
def test_tokenizer_defaults_preserved_by_kwargs_video(self):
|
||||
if "video_processor" not in self.processor_class.attributes:
|
||||
@@ -680,9 +662,10 @@ class ProcessorTesterMixin:
|
||||
|
||||
# TODO: the same test, but for audio + text processors that have strong overlap in kwargs
|
||||
# TODO (molbap) use the same structure of attribute kwargs for other tests to avoid duplication
|
||||
def test_overlapping_text_kwargs_handling(self):
|
||||
def test_overlapping_text_image_kwargs_handling(self):
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
processor_components = self.prepare_components()
|
||||
processor = self.processor_class(**processor_components)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
@@ -699,6 +682,28 @@ class ProcessorTesterMixin:
|
||||
text_kwargs={"padding": "do_not_pad"},
|
||||
)
|
||||
|
||||
def test_overlapping_text_audio_kwargs_handling(self):
|
||||
"""
|
||||
Checks that `padding`, or any other overlap arg between audio extractor and tokenizer
|
||||
is be passed to only text and ignored for audio for BC purposes
|
||||
"""
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
feature_extractor = self.get_component("feature_extractor")
|
||||
tokenizer = self.get_component("tokenizer")
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
|
||||
processor = self.processor_class(tokenizer=tokenizer, feature_extractor=feature_extractor, **processor_kwargs)
|
||||
self.skip_processor_without_typed_kwargs(processor)
|
||||
|
||||
input_str = self.prepare_text_inputs(batch_size=3)
|
||||
audio_lengths = [4000, 8000, 16000, 32000]
|
||||
raw_speech = [np.asarray(audio)[:length] for audio, length in zip(floats_list((3, 32_000)), audio_lengths)]
|
||||
|
||||
# padding = True should not raise an error and will if the audio processor popped its value to None
|
||||
_ = processor(text=input_str, audio=raw_speech, padding=True, return_tensors="pt")
|
||||
|
||||
def test_prepare_and_validate_optional_call_args(self):
|
||||
processor = self.get_processor()
|
||||
optional_call_args_name = getattr(processor, "optional_call_args", [])
|
||||
@@ -752,11 +757,14 @@ class ProcessorTesterMixin:
|
||||
# the reloaded tokenizer should get the chat template as well
|
||||
self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
|
||||
|
||||
def test_chat_template_single(self):
|
||||
def test_image_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
@@ -797,11 +805,14 @@ class ProcessorTesterMixin:
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 1)
|
||||
|
||||
def test_chat_template_batched(self):
|
||||
def test_image_chat_template_batched(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
batched_messages = [
|
||||
[
|
||||
{
|
||||
@@ -864,11 +875,14 @@ class ProcessorTesterMixin:
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 2)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 2)
|
||||
|
||||
def test_chat_template_accepts_processing_kwargs(self):
|
||||
def test_image_chat_template_accepts_processing_kwargs(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
@@ -915,11 +929,14 @@ class ProcessorTesterMixin:
|
||||
self.assertLessEqual(out_dict[self.images_input_name][0][0].mean(), 0)
|
||||
|
||||
@require_torch
|
||||
def test_chat_template_dict_torch(self):
|
||||
def test_image_chat_template_dict_torch(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "image_processor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -1171,3 +1188,117 @@ class ProcessorTesterMixin:
|
||||
self.assertTrue("Dummy prompt for preprocess testing" in formatted_text)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name]), 1)
|
||||
self.assertEqual(len(out_dict_with_video[self.videos_input_name][0]), 243)
|
||||
|
||||
@require_librosa
|
||||
def test_audio_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio",
|
||||
},
|
||||
{"type": "text", "text": "What's that sound?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio",
|
||||
},
|
||||
{"type": "text", "text": "How about this one?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template([messages], add_generation_prompt=True, tokenize=False)
|
||||
self.assertEqual(len(formatted_prompt), 1) # batch size=1
|
||||
|
||||
formatted_prompt_tokenized = processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_tensors=None
|
||||
)
|
||||
expected_output = processor.tokenizer(formatted_prompt, return_tensors=None).input_ids
|
||||
self.assertListEqual(expected_output, formatted_prompt_tokenized)
|
||||
|
||||
messages[1]["content"][0]["audio"] = (
|
||||
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
|
||||
)
|
||||
messages[3]["content"][0]["audio"] = (
|
||||
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3"
|
||||
)
|
||||
out_dict = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True)
|
||||
self.assertTrue(self.audio_input_name in out_dict)
|
||||
|
||||
# should always have input_ids and attention_mask
|
||||
self.assertEqual(len(out_dict["input_ids"]), 1) # batch-size=1
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1) # batch-size=1
|
||||
self.assertEqual(len(out_dict[self.audio_input_name]), 2) # 2 audios in the conversation
|
||||
|
||||
@require_torch
|
||||
@require_librosa
|
||||
def test_audio_chat_template_dict_torch(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
if "feature_extractor" not in self.processor_class.attributes:
|
||||
self.skipTest(f"feature_extractor attribute not present in {self.processor_class}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio",
|
||||
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
|
||||
},
|
||||
{"type": "text", "text": "What's that sound?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is the sound of glass shattering."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio",
|
||||
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/f2641_0_throatclearing.wav",
|
||||
},
|
||||
{"type": "text", "text": "How about this one?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
out_dict_tensors = processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
self.assertTrue(self.audio_input_name in out_dict_tensors)
|
||||
for k in out_dict_tensors:
|
||||
self.assertIsInstance(out_dict_tensors[k], torch.Tensor)
|
||||
|
||||
Reference in New Issue
Block a user