Support return_tensors in audio chat templates (#34601)
* add audio chat templates * update * update * nit * green ci * we dont care about the order anymore * clean up after rebase * overriden tests rename * rename shieldgemma also * one more rename * require_read_token * removde images/videos * retrigger CI flaky
This commit is contained in:
committed by
GitHub
parent
19085c28da
commit
0f733110a6
@@ -238,7 +238,7 @@ And who is that?<|im_end|>
|
||||
self.assertEqual(rendered, expected_rendered)
|
||||
|
||||
# Override as AriaImageProcessor doesn't accept `do_rescale`
|
||||
def test_chat_template_accepts_processing_kwargs(self):
|
||||
def test_image_chat_template_accepts_processing_kwargs(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
@@ -116,7 +116,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
def test_chat_template_single(self):
|
||||
def test_image_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
@@ -154,7 +154,7 @@ class Qwen2_5_VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
|
||||
|
||||
def test_chat_template_batched(self):
|
||||
def test_image_chat_template_batched(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
@@ -11,20 +11,63 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer, Qwen2AudioProcessor, WhisperFeatureExtractor
|
||||
from transformers.testing_utils import require_torch, require_torchaudio
|
||||
from transformers.utils import is_torch_available
|
||||
|
||||
from ...test_processing_common import ProcessorTesterMixin
|
||||
|
||||
|
||||
if is_torch_available:
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
class Qwen2AudioProcessorTest(unittest.TestCase):
|
||||
class Qwen2AudioProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Qwen2AudioProcessor
|
||||
|
||||
def setUp(self):
|
||||
self.checkpoint = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
self.tmpdirname = tempfile.mkdtemp()
|
||||
|
||||
processor_kwargs = self.prepare_processor_dict()
|
||||
processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint, **processor_kwargs)
|
||||
processor.save_pretrained(self.tmpdirname)
|
||||
|
||||
def get_tokenizer(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
|
||||
|
||||
def get_audio_processor(self, **kwargs):
|
||||
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).audio_processor
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdirname)
|
||||
|
||||
def prepare_processor_dict(self):
|
||||
return {
|
||||
"chat_template": "{% set audio_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if 'audio' in content or 'audio_url' in content or message['type'] == 'audio' %}{% set audio_count.value = audio_count.value + 1 %}Audio {{ audio_count.value }}: <|audio_bos|><|AUDIO|><|audio_eos|>\n{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}",
|
||||
}
|
||||
|
||||
# Override as Qwen2AudioProcessor needs audio tokens in prompts
|
||||
def prepare_text_inputs(self, batch_size: Optional[int] = None):
|
||||
if batch_size is None:
|
||||
return "lower newer <|AUDIO|>"
|
||||
|
||||
if batch_size < 1:
|
||||
raise ValueError("batch_size must be greater than 0")
|
||||
|
||||
if batch_size == 1:
|
||||
return ["lower newer <|AUDIO|>"]
|
||||
return ["lower newer <|AUDIO|>", "<|AUDIO|> upper older longer string"] + ["<|AUDIO|> lower newer"] * (
|
||||
batch_size - 2
|
||||
)
|
||||
|
||||
def test_can_load_various_tokenizers(self):
|
||||
processor = Qwen2AudioProcessor.from_pretrained(self.checkpoint)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
|
||||
@@ -77,7 +120,7 @@ class Qwen2AudioProcessorTest(unittest.TestCase):
|
||||
"assistant",
|
||||
"Ċ",
|
||||
]
|
||||
print(slow_tokenizer.tokenize(prompt))
|
||||
|
||||
self.assertEqual(slow_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
|
||||
self.assertEqual(fast_tokenizer.tokenize(prompt), EXPECTED_OUTPUT)
|
||||
|
||||
@@ -110,5 +153,31 @@ class Qwen2AudioProcessorTest(unittest.TestCase):
|
||||
},
|
||||
]
|
||||
|
||||
formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
self.assertEqual(expected_prompt, formatted_prompt)
|
||||
|
||||
def test_chat_template_with_continue_final_message(self):
|
||||
processor = AutoProcessor.from_pretrained(self.checkpoint)
|
||||
expected_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of " # fmt: skip
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": "You are a helpful assistant."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "audio",
|
||||
"audio": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3",
|
||||
},
|
||||
{"type": "text", "text": "What's that sound?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is the sound of "}],
|
||||
},
|
||||
]
|
||||
prompt = processor.apply_chat_template(messages, continue_final_message=True)
|
||||
self.assertEqual(expected_prompt, prompt)
|
||||
|
||||
@@ -113,7 +113,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
|
||||
|
||||
def test_chat_template_single(self):
|
||||
def test_image_chat_template_single(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
@@ -151,7 +151,7 @@ class Qwen2VLProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(len(out_dict["attention_mask"]), 1)
|
||||
self.assertEqual(len(out_dict[self.images_input_name]), 71280)
|
||||
|
||||
def test_chat_template_batched(self):
|
||||
def test_image_chat_template_batched(self):
|
||||
processor = self.get_processor()
|
||||
if processor.chat_template is None:
|
||||
self.skipTest("Processor has no chat template")
|
||||
|
||||
@@ -166,22 +166,22 @@ class ShieldGemma2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
|
||||
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
|
||||
def test_chat_template_accepts_processing_kwargs(self):
|
||||
def test_image_chat_template_accepts_processing_kwargs(self):
|
||||
pass
|
||||
|
||||
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
|
||||
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
|
||||
def test_chat_template_batched(self):
|
||||
def test_image_chat_template_batched(self):
|
||||
pass
|
||||
|
||||
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
|
||||
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
|
||||
def test_chat_template_dict_torch(self):
|
||||
def test_image_chat_template_dict_torch(self):
|
||||
pass
|
||||
|
||||
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
|
||||
@unittest.skip("ShieldGemma 2 chat template requires different message structure from parent.")
|
||||
def test_chat_template_single(self):
|
||||
def test_image_chat_template_single(self):
|
||||
pass
|
||||
|
||||
# TODO(ryanmullins): Adapt this test for ShieldGemma 2
|
||||
|
||||
@@ -18,8 +18,6 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.utils import FEATURE_EXTRACTOR_NAME
|
||||
@@ -30,6 +28,8 @@ from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Wav2Vec2Processor
|
||||
audio_input_name = "input_values"
|
||||
text_input_name = "labels"
|
||||
|
||||
def setUp(self):
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
@@ -132,22 +132,6 @@ class Wav2Vec2ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_padding_argument_not_ignored(self):
|
||||
# padding, or any other overlap arg between audio extractor and tokenizer
|
||||
# should be passed to both text and audio and not ignored
|
||||
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = Wav2Vec2Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
batch_duration_in_seconds = [1, 3, 2, 6]
|
||||
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
|
||||
|
||||
# padding = True should not raise an error and will if the audio processor popped its value to None
|
||||
_ = processor(
|
||||
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
@@ -18,8 +18,6 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.models.seamless_m4t import SeamlessM4TFeatureExtractor
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
@@ -32,6 +30,7 @@ from ..wav2vec2.test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
processor_class = Wav2Vec2BertProcessor
|
||||
text_input_name = "labels"
|
||||
|
||||
def setUp(self):
|
||||
vocab = "<pad> <s> </s> <unk> | E T A O N I H S R D L U M W C F G Y P B V K ' X J Q Z".split(" ")
|
||||
@@ -136,22 +135,6 @@ class Wav2Vec2BertProcessorTest(ProcessorTesterMixin, unittest.TestCase):
|
||||
for key in encoded_tok.keys():
|
||||
self.assertListEqual(encoded_tok[key], encoded_processor[key])
|
||||
|
||||
def test_padding_argument_not_ignored(self):
|
||||
# padding, or any other overlap arg between audio extractor and tokenizer
|
||||
# should be passed to both text and audio and not ignored
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
processor = Wav2Vec2BertProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)
|
||||
batch_duration_in_seconds = [1, 3, 2, 6]
|
||||
input_features = [np.random.random(16_000 * s) for s in batch_duration_in_seconds]
|
||||
|
||||
# padding = True should not raise an error and will if the audio processor popped its value to None
|
||||
# processor(input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt")
|
||||
_ = processor(
|
||||
input_features, padding=True, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
|
||||
def test_tokenizer_decode(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
Reference in New Issue
Block a user