diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 91de6810b1..96f91a0a43 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -34,7 +34,12 @@ from ...modeling_outputs import ( SequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_whisper import WhisperConfig from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE @@ -1464,6 +1469,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): task=None, language=None, is_multilingual=None, + prompt_ids: Optional[torch.Tensor] = None, **kwargs, ): """ @@ -1521,6 +1527,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. is_multilingual (`bool`, *optional*): Whether or not the model is multilingual. + prompt_ids (`torch.Tensor`, *optional*): + Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is + provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for + transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words + correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. 
If the model is an encoder-decoder model, encoder @@ -1567,8 +1578,21 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): if task is not None: generation_config.task = task - forced_decoder_ids = [] - if task is not None or language is not None: + forced_decoder_ids = None + + # Legacy code for backward compatibility + if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: + forced_decoder_ids = self.config.forced_decoder_ids + elif ( + hasattr(self.generation_config, "forced_decoder_ids") + and self.generation_config.forced_decoder_ids is not None + ): + forced_decoder_ids = self.generation_config.forced_decoder_ids + else: + forced_decoder_ids = kwargs.get("forced_decoder_ids", None) + + if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): + forced_decoder_ids = [] if hasattr(generation_config, "language"): if generation_config.language in generation_config.lang_to_id.keys(): language_token = generation_config.language @@ -1593,27 +1617,48 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): raise ValueError( f"The `{generation_config.task}`task is not supported. 
The task should be one of `{TASK_IDS}`" ) - else: + elif hasattr(generation_config, "task_to_id"): forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - # Legacy code for backward compatibility - elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: - forced_decoder_ids = self.config.forced_decoder_ids - elif ( - hasattr(self.generation_config, "forced_decoder_ids") - and self.generation_config.forced_decoder_ids is not None - ): - forced_decoder_ids = self.generation_config.forced_decoder_ids + if forced_decoder_ids is not None: + generation_config.forced_decoder_ids = forced_decoder_ids + + if prompt_ids is not None: + if kwargs.get("decoder_start_token_id") is not None: + raise ValueError( + "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." 
+ ) + prompt_ids = prompt_ids.tolist() + decoder_start_token_id, *text_prompt_ids = prompt_ids + # Set the decoder_start_token_id to <|startofprev|> + kwargs.update({"decoder_start_token_id": decoder_start_token_id}) + + # Update the max generation length to include the prompt + specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None) + default_max_length = generation_config.max_new_tokens or generation_config.max_length + non_prompt_max_length = specified_max_length or default_max_length + kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids) + + # Reformat the forced_decoder_ids to incorporate the prompt + non_prompt_forced_decoder_ids = ( + kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids + ) + forced_decoder_ids = [ + # Slicing the text prompt ids in a manner consistent with the OpenAI implementation + # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) + *text_prompt_ids[-self.config.max_length // 2 - 1 :], + generation_config.decoder_start_token_id, + *[token for _rank, token in non_prompt_forced_decoder_ids], + ] + forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] + generation_config.forced_decoder_ids = forced_decoder_ids if generation_config.return_timestamps: logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)] - if len(forced_decoder_ids) > 0: - generation_config.forced_decoder_ids = forced_decoder_ids - return super().generate( inputs, generation_config, diff --git a/src/transformers/models/whisper/processing_whisper.py b/src/transformers/models/whisper/processing_whisper.py index 8c158b041f..b0d0d6c954 100644 --- a/src/transformers/models/whisper/processing_whisper.py +++ b/src/transformers/models/whisper/processing_whisper.py @@ -16,6 +16,7 @@ Speech processor class for Whisper """ + from ...processing_utils 
def get_prompt_ids(self, text: str, return_tensors="np"):
    """
    Forward `text` and `return_tensors` to the tokenizer's `get_prompt_ids` and return its result unchanged.

    Args:
        text (`str`):
            Prompt text handed to the tokenizer.
        return_tensors (*optional*, defaults to `"np"`):
            Passed through to the tokenizer as the `return_tensors` keyword argument.
    """
    tokenizer = self.tokenizer
    prompt_ids = tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
    return prompt_ids
@staticmethod
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int) -> List[int]:
    """
    Remove a leading prompt section from a list of token IDs.

    A sequence is treated as prompted when it is a non-empty list whose first element equals `prompt_token_id`.
    For such a sequence everything before the first `decoder_start_token_id` is dropped.

    Args:
        token_ids (`List[int]`):
            Token IDs to inspect. Anything that is not a non-empty list is returned untouched.
        prompt_token_id (`int`):
            ID that marks the beginning of a prompt.
        decoder_start_token_id (`int`):
            ID at which the non-prompt part of the sequence begins.

    Returns:
        `token_ids` itself when no prompt is detected; otherwise the slice starting at the first
        `decoder_start_token_id`, or an empty list when a prompted sequence contains no such token.
    """
    # Guard clause: only a non-empty list that starts with the prompt marker is rewritten.
    if not isinstance(token_ids, list) or not token_ids or token_ids[0] != prompt_token_id:
        return token_ids

    # A single `index` call both locates the start token and detects its absence,
    # so the list is scanned once instead of once for `in` and again for `index`.
    try:
        return token_ids[token_ids.index(decoder_start_token_id) :]
    except ValueError:
        # The whole sequence is prompt: nothing is left after stripping.
        return []
@staticmethod
# Copied from transformers.models.whisper.tokenization_whisper.WhisperTokenizer._strip_prompt
def _strip_prompt(token_ids: List[int], prompt_token_id: int, decoder_start_token_id: int) -> List[int]:
    """
    Remove a leading prompt section from a list of token IDs.

    A sequence is treated as prompted when it is a non-empty list whose first element equals `prompt_token_id`.
    For such a sequence everything before the first `decoder_start_token_id` is dropped.

    Args:
        token_ids (`List[int]`):
            Token IDs to inspect. Anything that is not a non-empty list is returned untouched.
        prompt_token_id (`int`):
            ID that marks the beginning of a prompt.
        decoder_start_token_id (`int`):
            ID at which the non-prompt part of the sequence begins.

    Returns:
        `token_ids` itself when no prompt is detected; otherwise the slice starting at the first
        `decoder_start_token_id`, or an empty list when a prompted sequence contains no such token.
    """
    # Guard clause: only a non-empty list that starts with the prompt marker is rewritten.
    if not isinstance(token_ids, list) or not token_ids or token_ids[0] != prompt_token_id:
        return token_ids

    # A single `index` call both locates the start token and detects its absence,
    # so the list is scanned once instead of once for `in` and again for `index`.
    try:
        return token_ids[token_ids.index(decoder_start_token_id) :]
    except ValueError:
        # The whole sequence is prompt: nothing is left after stripping.
        return []
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
    """
    Generate with both `prompt_ids` and explicit `forced_decoder_ids`.

    Every generated row is expected to start with the prompt IDs, followed by the generation config's
    `decoder_start_token_id`, followed by the forced tokens in order.
    """
    config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
    model = WhisperForConditionalGeneration(config).eval().to(torch_device)
    features = inputs["input_features"]
    prompt_ids = np.asarray(range(5))
    forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

    generated = model.generate(
        features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids
    )

    # Build the expected prefix piece by piece: prompt, start token, then the forced tokens.
    expected_prefix = prompt_ids.tolist()
    expected_prefix.append(model.generation_config.decoder_start_token_id)
    expected_prefix.extend(token for _rank, token in forced_decoder_ids)

    prefix_length = len(expected_prefix)
    for sequence in generated.tolist():
        self.assertListEqual(sequence[:prefix_length], expected_prefix)
@slow
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
    """
    Generate with a prompt together with `task` and `language`, then inspect the decoded text.

    The decoded output must contain the prompt text as well as the `<|translate|>` and `<|de|>` tokens.
    """
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)
    input_speech = self._load_datasamples(1)
    # The model lives on `torch_device`, so the features have to be moved there as well; leaving them
    # on the CPU makes `generate` fail with a device mismatch whenever `torch_device` is an accelerator.
    input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
    task = "translate"
    language = "de"
    expected_tokens = [f"<|{task}|>", f"<|{language}|>"]
    prompt = "test prompt"
    prompt_ids = processor.get_prompt_ids(prompt)

    output = model.generate(input_features, task=task, language=language, prompt_ids=prompt_ids)
    text = processor.decode(output[0])

    # `assertIn` names the missing substring on failure, which `assertTrue(... in ...)` does not.
    self.assertIn(prompt, text)
    for token in expected_tokens:
        self.assertIn(token, text)


@slow
def test_generate_with_prompt_ids_and_no_non_prompt_forced_decoder_ids(self):
    """
    Generate with a prompt after clearing `forced_decoder_ids` on both the model config and the generation
    config, and check that the decoded text still contains the prompt.
    """
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
    model.to(torch_device)
    input_speech = self._load_datasamples(1)
    # Same device as the model, see the reasoning in the docstring-level contract above the call to `generate`.
    input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)
    prompt = "test prompt"
    prompt_ids = processor.get_prompt_ids(prompt)

    # Remove every source of non-prompt forced decoder ids before generating.
    model.generation_config.forced_decoder_ids = None
    model.config.forced_decoder_ids = None

    output = model.generate(input_features, prompt_ids=prompt_ids, return_timestamps=True)
    text = processor.decode(output[0])

    self.assertIn(prompt, text)
def test_get_prompt_ids(self):
    """Check the IDs returned for the prompt "Mr. Quilter" and the text they decode back to."""
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    ids = processor.get_prompt_ids("Mr. Quilter")
    round_tripped = processor.tokenizer.decode(ids)

    self.assertListEqual(ids.tolist(), [50360, 1770, 13, 2264, 346, 353])
    self.assertEqual(round_tripped, "<|startofprev|> Mr. Quilter")


def test_empty_get_prompt_ids(self):
    """Check the IDs returned for an empty prompt and the text they decode back to."""
    tokenizer = self.get_tokenizer()
    feature_extractor = self.get_feature_extractor()
    processor = WhisperProcessor(tokenizer=tokenizer, feature_extractor=feature_extractor)

    ids = processor.get_prompt_ids("")
    round_tripped = processor.tokenizer.decode(ids)

    self.assertListEqual(ids.tolist(), [50360, 220])
    self.assertEqual(round_tripped, "<|startofprev|> ")
def test_skip_special_tokens_skips_prompt_ids(self):
    """
    Decode a prompted sequence with the slow and the fast tokenizer.

    With `skip_special_tokens=False` the full text, prompt included, is expected; with
    `skip_special_tokens=True` only the transcription after the prompt is expected.
    """
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()
    # fmt: off
    encoded_input = [
        50361, 2221, 13, 2326, 388, 391, 50258, 50259, 50359,
        50363, 1282, 264, 2674, 9156, 295, 1523, 11, 2221, 13,
        2326, 388, 391, 13657, 365, 2681, 21296, 17711, 13, 50257,
    ]
    # fmt: on
    expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
    expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."

    # Same four checks in the same order: slow tokenizer first, then the fast one.
    for candidate in (slow_tokenizer, fast_tokenizer):
        kept = candidate.decode(encoded_input, skip_special_tokens=False)
        self.assertEqual(kept, expected_with_special_tokens)
        stripped = candidate.decode(encoded_input, skip_special_tokens=True)
        self.assertEqual(stripped, expected_without_special_tokens)