feat: Whisper prompting (#22496)
* initial working additions * clean and rename, add cond stripping initial prompt to decode * cleanup, edit create_initial_prompt_ids, add tests * repo consistency, flip order of conditional * fix error, move the processor fn to the tokenizer * repo consistency, update test ids to corresponding tokenizer * use convert_tokens_to_ids not get_vocab... * use actual conditional in generate * make style * initial address comments * initial working add new params to pipeline * first draft of sequential generation for condition_on_previous_text * add/update tests, make compatible with timestamps * make compatible with diff. input kwargs and max length * add None check * add temperature check * flip temp check operand * refocusing to prev pr scope * remove the params too * make style * edits, move max length incorporating prompt to whisper * address comments * remove asr pipeline prompt decoding, fix indexing * address comments (more tests, validate prompt) * un-comment out tests (from debug) * remove old comment * address comments * fix typo * remove timestamp token from test * make style * cleanup * copy method to fast tokenizer, set max_new_tokens for test * prompt_ids type just pt * address Amy's comments * make style
This commit is contained in:
@@ -1013,6 +1013,48 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
||||
encoder_last_hidden_state = model(**input_dict).encoder_last_hidden_state
|
||||
self.assertTrue(encoder_last_hidden_state.shape, (13, 30, 16))
|
||||
|
||||
def test_generate_with_prompt_ids_and_task_and_language(self):
    """Check the start of every generated row when a prompt, a task and a language are given.

    Each row returned by `generate` is expected to begin with the prompt ids,
    followed by the decoder start token id, the language id and the task id.
    The `lang_to_id` / `task_to_id` mappings are set on the generation config
    with fixed ids so the expected prefix can be written down directly.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = WhisperForConditionalGeneration(config).eval().to(torch_device)
    input_features = input_dict["input_features"]
    prompt_ids = np.arange(5)
    language = "<|de|>"
    task = "translate"
    lang_id = 6
    task_id = 7
    # Plain attribute assignment; there is no need to call the `__setattr__` dunder directly.
    model.generation_config.lang_to_id = {language: lang_id}
    model.generation_config.task_to_id = {task: task_id}

    output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

    expected_output_start = [
        *prompt_ids.tolist(),
        model.generation_config.decoder_start_token_id,
        lang_id,
        task_id,
    ]
    for row in output.tolist():
        self.assertListEqual(row[: len(expected_output_start)], expected_output_start)
|
||||
|
||||
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
    """Check the start of every generated row when a prompt and forced decoder ids are given.

    Each row is expected to begin with the prompt ids, then the decoder start
    token id, then the token of every `(rank, token)` forced decoder pair.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
    model = WhisperForConditionalGeneration(config).eval().to(torch_device)
    features = inputs["input_features"]
    prompt_ids = np.asarray(range(5))
    forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

    generated = model.generate(
        features,
        max_new_tokens=5,
        forced_decoder_ids=forced_decoder_ids,
        prompt_ids=prompt_ids,
    )

    # Only the token half of each (rank, token) pair appears in the output.
    forced_tokens = [pair[1] for pair in forced_decoder_ids]
    expected_prefix = prompt_ids.tolist() + [model.generation_config.decoder_start_token_id] + forced_tokens
    for sequence in generated.tolist():
        self.assertListEqual(sequence[: len(expected_prefix)], expected_prefix)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@@ -1429,6 +1471,60 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
# fmt: on
|
||||
self.assertTrue(torch.allclose(logits[0][0, 0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
def test_generate_with_prompt_ids(self):
    """Compare whisper-tiny output with and without a "Leighton" prompt.

    The expected strings differ in the spelling of the name ("Layton" without
    the prompt, "Leighton" with it) and in the `<|startofprev|> Leighton`
    prefix that appears only in the prompted output.
    """
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)
    input_speech = self._load_datasamples(4)[-1:]
    # The model lives on `torch_device`; the features must be moved there too,
    # otherwise `generate` hits a device mismatch on accelerators.
    input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

    output_without_prompt = model.generate(input_features)
    prompt_ids = processor.get_prompt_ids("Leighton")
    output_with_prompt = model.generate(input_features, prompt_ids=prompt_ids)

    expected_without_prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
    expected_with_prompt = "<|startofprev|> Leighton<|startoftranscript|><|en|><|transcribe|><|notimestamps|> He has grave doubts whether Sir Frederick Leighton's work is really Greek after all and can discover in it but little of Rocky Ithaca.<|endoftext|>"
    self.assertEqual(processor.decode(output_without_prompt[0]), expected_without_prompt)
    self.assertEqual(processor.decode(output_with_prompt[0]), expected_with_prompt)
|
||||
|
||||
@slow
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
    """Check that prompt text and the task/language tokens all show up in the decoded output.

    Generates with whisper-tiny using `task="translate"`, `language="de"` and a
    text prompt, then asserts that the decoded string contains the prompt as
    well as the `<|translate|>` and `<|de|>` tokens.
    """
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)
    input_speech = self._load_datasamples(1)
    # Keep the features on the same device as the model to avoid a device mismatch.
    input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
    task = "translate"
    language = "de"
    expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
    prompt = "test prompt"
    prompt_ids = processor.get_prompt_ids(prompt)

    output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
    text = processor.decode(output[0])

    # `assertIn` reports the missing substring on failure, unlike `assertTrue(x in y)`.
    self.assertIn(prompt, text)
    for token in expected_tokens:
        self.assertIn(token, text)
|
||||
|
||||
@slow
def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
    """Check that a prompt still appears in the output when no forced decoder ids are configured.

    `forced_decoder_ids` is cleared on both the generation config and the model
    config before calling `generate` with `return_timestamps=True`.
    """
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    model.to(torch_device)
    input_speech = self._load_datasamples(1)
    # Keep the features on the same device as the model to avoid a device mismatch.
    input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
    prompt = "test prompt"
    prompt_ids = processor.get_prompt_ids(prompt)

    model.generation_config.forced_decoder_ids = None
    model.config.forced_decoder_ids = None

    output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
    text = processor.decode(output[0])

    # `assertIn` reports both operands on failure, unlike `assertTrue(prompt in text)`.
    self.assertIn(prompt, text)
|
||||
|
||||
|
||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||
if head_mask is None:
|
||||
|
||||
@@ -16,6 +16,8 @@ import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from transformers import WhisperTokenizer, is_speech_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_torch, require_torchaudio
|
||||
|
||||
@@ -146,3 +148,32 @@ class WhisperProcessorTest(unittest.TestCase):
|
||||
|
||||
expected_ids = [TRANSCRIBE, NOTIMESTAMPS]
|
||||
self.assertListEqual([ids[-1] for ids in forced_decoder_ids], expected_ids)
|
||||
|
||||
def test_get_prompt_ids(self):
    """Check the ids `get_prompt_ids` returns for "Mr. Quilter" and their decoded form."""
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    ids = processor.get_prompt_ids("Mr. Quilter")
    decoded = processor.tokenizer.decode(ids)

    # The decoded text shows the prompt prefixed with `<|startofprev|>`.
    expected_ids = [50360, 1770, 13, 2264, 346, 353]
    self.assertListEqual(ids.tolist(), expected_ids)
    self.assertEqual(decoded, "<|startofprev|> Mr. Quilter")
|
||||
|
||||
def test_empty_get_prompt_ids(self):
    """Check the ids `get_prompt_ids` returns for an empty prompt and their decoded form."""
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    ids = processor.get_prompt_ids("")
    decoded = processor.tokenizer.decode(ids)

    # An empty prompt is expected to yield two ids that decode to "<|startofprev|> ".
    expected_ids = [50360, 220]
    self.assertListEqual(ids.tolist(), expected_ids)
    self.assertEqual(decoded, "<|startofprev|> ")
|
||||
|
||||
def test_get_prompt_ids_with_special_tokens(self):
    """Check that prompts containing special-token text raise a `ValueError` naming the token.

    The last case contains two special tokens and expects the first one
    (`<|zh|>`) to be the one reported in the error message.
    """
    processor = WhisperProcessor(tokenizer=self.get_tokenizer(), feature_extractor=self.get_feature_extractor())

    # (prompt, special token expected in the error message)
    cases = (
        ("<|startofprev|> test", "<|startofprev|>"),
        ("test <|notimestamps|>", "<|notimestamps|>"),
        ("test <|zh|> test <|transcribe|>", "<|zh|>"),
    )
    for prompt, special_token in cases:
        with pytest.raises(ValueError) as excinfo:
            processor.get_prompt_ids(prompt)
        # Checked after the `with` block: statements following the raising call inside it never run.
        expected = f"Encountered text in the prompt corresponding to disallowed special token: {special_token}."
        self.assertEqual(expected, str(excinfo.value))
|
||||
|
||||
@@ -194,6 +194,25 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
merge = _find_longest_common_sequence([seq1, seq2, seq3])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
|
||||
def test_skip_special_tokens_skips_prompt_ids(self):
    """Check that decoding with `skip_special_tokens=True` drops the prompt section.

    The same id sequence is decoded by both the slow and the rust tokenizer,
    once keeping special tokens and once skipping them; the skipped form is
    expected to contain only the transcription text.
    """
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()
    # fmt: off
    encoded_input = [
        50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
        50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
        2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
    ]
    # fmt: on
    expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
    expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."

    # Same two checks for each tokenizer, slow one first.
    for tok in (slow_tokenizer, fast_tokenizer):
        kept = tok.decode(encoded_input, skip_special_tokens=False)
        self.assertEqual(kept, expected_with_special_tokens)
        skipped = tok.decode(encoded_input, skip_special_tokens=True)
        self.assertEqual(skipped, expected_without_special_tokens)
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
Reference in New Issue
Block a user