From 2898fd396831bccd71e4e7056c2ed816c4a11406 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 23 Jun 2023 14:27:49 +0200 Subject: [PATCH] Fix some `TFWhisperModelIntegrationTests` (#24428) * fix * fix * fix * fix * fix * fix * fix * fix * fix * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix * fix * fix --------- Co-authored-by: ydshieh Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/whisper/modeling_tf_whisper.py | 227 +++++++++++++++++- .../whisper/test_modeling_tf_whisper.py | 64 ++++- 2 files changed, 278 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index 4d6ecb85b5..c36450dd12 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -19,12 +19,14 @@ from __future__ import annotations import math import random -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import numpy as np import tensorflow as tf from ...activations_tf import get_tf_activation +from ...generation.configuration_utils import GenerationConfig +from ...generation.tf_logits_process import TFLogitsProcessorList from ...modeling_tf_outputs import ( TFBaseModelOutput, TFBaseModelOutputWithPastAndCrossAttentions, @@ -41,6 +43,7 @@ from ...modeling_tf_utils import ( from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_whisper import WhisperConfig +from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE logger = logging.get_logger(__name__) @@ -1324,6 +1327,228 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua encoder_attentions=outputs.encoder_attentions, ) + def generate( + self, + inputs: Optional[tf.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[TFLogitsProcessorList] = None, + seed: Optional[List[int]] = None, + return_timestamps: Optional[bool] = None, + task: Optional[str] = None, + language: Optional[str] = None, + is_multilingual: Optional[bool] = None, + prompt_ids: Optional[tf.Tensor] = None, + return_token_timestamps=None, + **kwargs, + ): + r""" + Generates sequences of token ids for models with a language modeling head. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](../generation_strategies). + + + + Parameters: + inputs (`tf.Tensor` of varying shape depending on the modality, *optional*): + The sequence used as a prompt for the generation or as model inputs to the encoder. If unset the method + initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in + the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, + `input_values`, `input_features`, or `pixel_values`. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + seed (`List[int]`, *optional*): + Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the + `seed` argument from stateless functions in `tf.random`. + return_timestamps (`bool`, *optional*): + Whether to return the timestamps with the text. This enables the `TFWhisperTimestampsLogitsProcessor`. + task (`str`, *optional*): + Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` + will be updated accordingly. + language (`str`, *optional*): + Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can + find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. + is_multilingual (`bool`, *optional*): + Whether or not the model is multilingual. + prompt_ids (`tf.Tensor`, *optional*): + Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is + provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for + transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words + correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. + return_token_timestamps (`bool`, *optional*): + Whether to return token-level timestamps with the text. This can be used with or without the + `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into + words. + kwargs: + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when + `config.return_dict_in_generate=True`) or a `tf.Tensor`. + + If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.TFGreedySearchDecoderOnlyOutput`], + - [`~generation.TFSampleDecoderOnlyOutput`], + - [`~generation.TFBeamSearchDecoderOnlyOutput`], + - [`~generation.TFBeamSampleDecoderOnlyOutput`] + + If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible + [`~utils.ModelOutput`] types are: + + - [`~generation.TFGreedySearchEncoderDecoderOutput`], + - [`~generation.TFSampleEncoderDecoderOutput`], + - [`~generation.TFBeamSearchEncoderDecoderOutput`], + - [`~generation.TFBeamSampleEncoderDecoderOutput`] + + """ + if generation_config is None: + generation_config = self.generation_config + + if return_timestamps is not None: + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You are trying to return timestamps, but the generation config is not properly set." + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`." + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + ) + + generation_config.return_timestamps = return_timestamps + else: + generation_config.return_timestamps = False + + if language is not None: + language = language.lower() + generation_config.language = language + if task is not None: + generation_config.task = task + + forced_decoder_ids = None + + # Legacy code for backward compatibility + if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: + forced_decoder_ids = self.config.forced_decoder_ids + elif ( + hasattr(self.generation_config, "forced_decoder_ids") + and self.generation_config.forced_decoder_ids is not None + ): + forced_decoder_ids = self.generation_config.forced_decoder_ids + else: + forced_decoder_ids = kwargs.get("forced_decoder_ids", None) + + if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): + forced_decoder_ids = [] + if hasattr(generation_config, "language"): + if generation_config.language in generation_config.lang_to_id.keys(): + language_token = generation_config.language + elif generation_config.language in TO_LANGUAGE_CODE.keys(): + language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" + elif generation_config.language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{generation_config.language}|>" + else: + is_language_code = len(generation_config.language) == 2 + raise ValueError( + f"Unsupported language: {generation_config.language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." + ) + forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) + else: + forced_decoder_ids.append((1, None)) # automatically detect the language + + if hasattr(generation_config, "task"): + if generation_config.task in TASK_IDS: + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + raise ValueError( + f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" + ) + elif hasattr(generation_config, "task_to_id"): + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe + if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if forced_decoder_ids is not None: + generation_config.forced_decoder_ids = forced_decoder_ids + + if prompt_ids is not None: + if kwargs.get("decoder_start_token_id") is not None: + raise ValueError( + "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." + ) + prompt_ids = prompt_ids.tolist() + decoder_start_token_id, *text_prompt_ids = prompt_ids + # Slicing the text prompt ids in a manner consistent with the OpenAI implementation + # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) + text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :] + # Set the decoder_start_token_id to <|startofprev|> + kwargs.update({"decoder_start_token_id": decoder_start_token_id}) + + # Update the max generation length to include the prompt + specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None) + default_max_length = generation_config.max_new_tokens or generation_config.max_length + non_prompt_max_length = specified_max_length or default_max_length + kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids) + + # Reformat the forced_decoder_ids to incorporate the prompt + non_prompt_forced_decoder_ids = ( + kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids + ) + forced_decoder_ids = [ + *text_prompt_ids, + generation_config.decoder_start_token_id, + *[token for _rank, token in non_prompt_forced_decoder_ids], + ] + forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] + generation_config.forced_decoder_ids = forced_decoder_ids + + # TODO: Implement `WhisperTimeStampLogitsProcessor`. + if generation_config.return_timestamps: + # logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)] + raise ValueError("`TFWhisperForConditionalGeneration` doesn't support returning the timestamps yet.") + + if return_token_timestamps: + kwargs["output_attentions"] = True + kwargs["return_dict_in_generate"] = True + + if getattr(generation_config, "task", None) == "translate": + logger.warning("Token-level timestamps may not be reliable for task 'translate'.") + if not hasattr(generation_config, "alignment_heads"): + raise ValueError( + "Model generation config has no `alignment_heads`, token-level timestamps not available. " + "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." + ) + + outputs = super().generate( + inputs, + generation_config, + logits_processor, + **kwargs, + ) + + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads) + + return outputs + def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None diff --git a/tests/models/whisper/test_modeling_tf_whisper.py b/tests/models/whisper/test_modeling_tf_whisper.py index 0783bd67bf..1bf5c2ccc2 100644 --- a/tests/models/whisper/test_modeling_tf_whisper.py +++ b/tests/models/whisper/test_modeling_tf_whisper.py @@ -634,6 +634,48 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC generated_ids = output_tokens[:, input_features.shape[-1] :] self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids)) + def test_generate_with_prompt_ids_and_task_and_language(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = TFWhisperForConditionalGeneration(config) + input_features = input_dict["input_features"] + prompt_ids = np.arange(5) + language = "<|de|>" + task = "translate" + lang_id = 6 + task_id = 7 + model.generation_config.__setattr__("lang_to_id", {language: lang_id}) + model.generation_config.__setattr__("task_to_id", {task: task_id}) + + output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids) + + expected_output_start = [ + *prompt_ids.tolist(), + model.generation_config.decoder_start_token_id, + lang_id, + task_id, + ] + for row in output.numpy().tolist(): + self.assertListEqual(row[: len(expected_output_start)], expected_output_start) + + def test_generate_with_prompt_ids_and_forced_decoder_ids(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = TFWhisperForConditionalGeneration(config) + input_features = input_dict["input_features"] + prompt_ids = np.asarray(range(5)) + forced_decoder_ids = [(1, 6), (2, 7), (3, 8)] + + output = model.generate( + input_features, max_new_tokens=5, forced_decoder_ids=forced_decoder_ids, prompt_ids=prompt_ids + ) + + expected_output_start = [ + *prompt_ids.tolist(), + model.generation_config.decoder_start_token_id, + *[token for _rank, token in forced_decoder_ids], + ] + for row in output.numpy().tolist(): + self.assertListEqual(row[: len(expected_output_start)], expected_output_start) + def _load_datasamples(num_samples): ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") @@ -779,24 +821,22 @@ def _test_large_batched_generation(in_queue, out_queue, timeout): generated_ids = np.concatenate([generated_ids_1, generated_ids_2]) # fmt: off - EXPECTED_LOGITS = tf.convert_to_tensor( - [ - [50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404], - [50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257], - [50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904], - [50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439] - ] - ) + EXPECTED_IDS = [ + [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281], + [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257], + [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256], + [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11] + ] # fmt: on - unittest.TestCase().assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) + unittest.TestCase().assertEqual(generated_ids.tolist(), EXPECTED_IDS) # fmt: off EXPECTED_TRANSCRIPT = [ - " Mr. Quilter is the apostle of the middle classes and we are glad", + " Mr. Quilter is the apostle of the middle classes and we are glad to", " Nor is Mr. Quilter's manner less interesting than his matter.", - " He tells us that at this festive season of the year, with Christmas and roast", - " He has grave doubts whether Sir Frederick Layton's work is really Greek after all", + " He tells us that at this festive season of the year, with Christmas and roast beef", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all," ] # fmt: on