Fix some TFWhisperModelIntegrationTests (#24428)
* fix * fix * fix * fix * fix * fix * fix * fix * fix * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -19,12 +19,14 @@ from __future__ import annotations
|
||||
|
||||
import math
|
||||
import random
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...generation.configuration_utils import GenerationConfig
|
||||
from ...generation.tf_logits_process import TFLogitsProcessorList
|
||||
from ...modeling_tf_outputs import (
|
||||
TFBaseModelOutput,
|
||||
TFBaseModelOutputWithPastAndCrossAttentions,
|
||||
@@ -41,6 +43,7 @@ from ...modeling_tf_utils import (
|
||||
from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
|
||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
||||
from .configuration_whisper import WhisperConfig
|
||||
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -1324,6 +1327,228 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
|
||||
encoder_attentions=outputs.encoder_attentions,
|
||||
)
|
||||
|
||||
def generate(
    self,
    inputs: Optional[tf.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[TFLogitsProcessorList] = None,
    seed: Optional[List[int]] = None,
    return_timestamps: Optional[bool] = None,
    task: Optional[str] = None,
    language: Optional[str] = None,
    is_multilingual: Optional[bool] = None,
    prompt_ids: Optional[tf.Tensor] = None,
    return_token_timestamps=None,
    **kwargs,
):
    r"""
    Generates sequences of token ids for models with a language modeling head.

    <Tip warning={true}>

    Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
    model's default generation configuration. You can override any `generation_config` by passing the corresponding
    parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

    For an overview of generation strategies and code examples, check out the [following
    guide](../generation_strategies).

    </Tip>

    Parameters:
        inputs (`tf.Tensor` of varying shape depending on the modality, *optional*):
            The sequence used as a prompt for the generation or as model inputs to the encoder. If unset the method
            initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should be in
            the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`,
            `input_values`, `input_features`, or `pixel_values`.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which had the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.
        logits_processor (`LogitsProcessorList`, *optional*):
            Custom logits processors that complement the default logits processors built from arguments and
            generation config. If a logit processor is passed that is already created with the arguments or a
            generation config an error is thrown. This feature is intended for advanced users.
        seed (`List[int]`, *optional*):
            Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
            `seed` argument from stateless functions in `tf.random`.
        return_timestamps (`bool`, *optional*):
            Whether to return the timestamps with the text. This enables the `TFWhisperTimestampsLogitsProcessor`.
        task (`str`, *optional*):
            Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
            will be updated accordingly.
        language (`str`, *optional*):
            Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
            find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
        is_multilingual (`bool`, *optional*):
            Whether or not the model is multilingual.
        prompt_ids (`tf.Tensor`, `np.ndarray` or sequence of `int`, *optional*):
            Rank-1 collection of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
            provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
            transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
            correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
        return_token_timestamps (`bool`, *optional*):
            Whether to return token-level timestamps with the text. This can be used with or without the
            `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
            words.
        kwargs:
            Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
            forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
            specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

    Return:
        [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
        `config.return_dict_in_generate=True`) or a `tf.Tensor`.

            If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.TFGreedySearchDecoderOnlyOutput`],
                - [`~generation.TFSampleDecoderOnlyOutput`],
                - [`~generation.TFBeamSearchDecoderOnlyOutput`],
                - [`~generation.TFBeamSampleDecoderOnlyOutput`]

            If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
            [`~utils.ModelOutput`] types are:

                - [`~generation.TFGreedySearchEncoderDecoderOutput`],
                - [`~generation.TFSampleEncoderDecoderOutput`],
                - [`~generation.TFBeamSearchEncoderDecoderOutput`],
                - [`~generation.TFBeamSampleEncoderDecoderOutput`]

    Raises:
        ValueError: If `return_timestamps` is passed but the generation config has no `no_timestamps_token_id`; if
            the language or task is unsupported; if `prompt_ids` is combined with `decoder_start_token_id`; if
            timestamps are requested (not implemented for TF yet); or if `return_token_timestamps` is set but the
            generation config has no `alignment_heads`.
    """
    # NOTE(review): `seed` and `is_multilingual` are accepted but never read below, and `seed` is not
    # forwarded to `super().generate` -- confirm whether that is intentional.
    if generation_config is None:
        generation_config = self.generation_config

    if return_timestamps is not None:
        if not hasattr(generation_config, "no_timestamps_token_id"):
            raise ValueError(
                "You are trying to return timestamps, but the generation config is not properly set."
                "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
                "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
            )

        generation_config.return_timestamps = return_timestamps
    else:
        generation_config.return_timestamps = False

    if language is not None:
        language = language.lower()
        generation_config.language = language
    if task is not None:
        generation_config.task = task

    forced_decoder_ids = None

    # Legacy code for backward compatibility
    if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
        forced_decoder_ids = self.config.forced_decoder_ids
    elif (
        hasattr(self.generation_config, "forced_decoder_ids")
        and self.generation_config.forced_decoder_ids is not None
    ):
        forced_decoder_ids = self.generation_config.forced_decoder_ids
    else:
        forced_decoder_ids = kwargs.get("forced_decoder_ids", None)

    if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
        # Rebuild the forced ids from scratch: position 1 is the language, position 2 the task.
        forced_decoder_ids = []
        if hasattr(generation_config, "language"):
            if generation_config.language in generation_config.lang_to_id:
                language_token = generation_config.language
            elif generation_config.language in TO_LANGUAGE_CODE:
                language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
            elif generation_config.language in TO_LANGUAGE_CODE.values():
                language_token = f"<|{generation_config.language}|>"
            else:
                is_language_code = len(generation_config.language) == 2
                raise ValueError(
                    f"Unsupported language: {generation_config.language}. Language should be one of:"
                    f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
                )
            forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
        else:
            forced_decoder_ids.append((1, None))  # automatically detect the language

        if hasattr(generation_config, "task"):
            if generation_config.task in TASK_IDS:
                forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
            else:
                raise ValueError(
                    f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
                )
        elif hasattr(generation_config, "task_to_id"):
            forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
        if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
            idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
            forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

    if forced_decoder_ids is not None:
        generation_config.forced_decoder_ids = forced_decoder_ids

    if prompt_ids is not None:
        if kwargs.get("decoder_start_token_id") is not None:
            raise ValueError(
                "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
            )
        # Go through NumPy so that tensors, arrays and plain sequences are all accepted;
        # calling `.tolist()` directly only works for NumPy arrays.
        prompt_ids = np.asarray(prompt_ids).tolist()
        decoder_start_token_id, *text_prompt_ids = prompt_ids
        # Slicing the text prompt ids in a manner consistent with the OpenAI implementation
        # to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
        text_prompt_ids = text_prompt_ids[-self.config.max_length // 2 - 1 :]
        # Set the decoder_start_token_id to <|startofprev|>
        kwargs.update({"decoder_start_token_id": decoder_start_token_id})

        # Update the max generation length to include the prompt
        specified_max_length = kwargs.pop("max_new_tokens", None) or kwargs.pop("max_length", None)
        default_max_length = generation_config.max_new_tokens or generation_config.max_length
        non_prompt_max_length = specified_max_length or default_max_length
        kwargs["max_new_tokens"] = non_prompt_max_length + len(text_prompt_ids)

        # Reformat the forced_decoder_ids to incorporate the prompt
        non_prompt_forced_decoder_ids = (
            kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
        )
        forced_decoder_ids = [
            *text_prompt_ids,
            generation_config.decoder_start_token_id,
            *[token for _rank, token in non_prompt_forced_decoder_ids],
        ]
        # Ranks are 1-based positions in the decoder sequence.
        forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
        generation_config.forced_decoder_ids = forced_decoder_ids

    # TODO: Implement `WhisperTimeStampLogitsProcessor`.
    if generation_config.return_timestamps:
        # logits_processor = [TFWhisperTimeStampLogitsProcessor(generation_config)]
        raise ValueError("`TFWhisperForConditionalGeneration` doesn't support returning the timestamps yet.")

    if return_token_timestamps:
        # Token-level timestamps are derived from the attentions, so they must be returned.
        kwargs["output_attentions"] = True
        kwargs["return_dict_in_generate"] = True

        if getattr(generation_config, "task", None) == "translate":
            logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
        if not hasattr(generation_config, "alignment_heads"):
            raise ValueError(
                "Model generation config has no `alignment_heads`, token-level timestamps not available. "
                "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
            )

    outputs = super().generate(
        inputs,
        generation_config,
        logits_processor,
        **kwargs,
    )

    if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
        outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)

    return outputs
|
||||
|
||||
def serving_output(self, output):
|
||||
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
|
||||
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
|
||||
|
||||
@@ -634,6 +634,48 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
def test_generate_with_prompt_ids_and_task_and_language(self):
    """Check that every generated row starts with the prompt ids, then the decoder start
    token, then the language id and the task id registered on the generation config."""
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = TFWhisperForConditionalGeneration(config)
    input_features = input_dict["input_features"]
    prompt_ids = np.arange(5)

    # Register a single language and a single task with arbitrary ids.
    language, lang_id = "<|de|>", 6
    task, task_id = "translate", 7
    setattr(model.generation_config, "lang_to_id", {language: lang_id})
    setattr(model.generation_config, "task_to_id", {task: task_id})

    output = model.generate(input_features, max_new_tokens=5, task=task, language=language, prompt_ids=prompt_ids)

    expected_prefix = prompt_ids.tolist() + [
        model.generation_config.decoder_start_token_id,
        lang_id,
        task_id,
    ]
    prefix_length = len(expected_prefix)
    for generated_row in output.numpy().tolist():
        self.assertListEqual(generated_row[:prefix_length], expected_prefix)
|
||||
|
||||
def test_generate_with_prompt_ids_and_forced_decoder_ids(self):
    """Check that every generated row starts with the prompt ids, then the decoder start
    token, then the tokens of the explicitly passed `forced_decoder_ids`."""
    config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = TFWhisperForConditionalGeneration(config)
    input_features = input_dict["input_features"]
    prompt_ids = np.asarray(range(5))
    forced_decoder_ids = [(1, 6), (2, 7), (3, 8)]

    output = model.generate(
        input_features,
        max_new_tokens=5,
        forced_decoder_ids=forced_decoder_ids,
        prompt_ids=prompt_ids,
    )

    # Only the token of each (rank, token) pair is expected in the output.
    forced_tokens = [pair[1] for pair in forced_decoder_ids]
    expected_prefix = prompt_ids.tolist() + [model.generation_config.decoder_start_token_id] + forced_tokens
    prefix_length = len(expected_prefix)
    for generated_row in output.numpy().tolist():
        self.assertListEqual(generated_row[:prefix_length], expected_prefix)
|
||||
|
||||
|
||||
def _load_datasamples(num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
@@ -779,24 +821,22 @@ def _test_large_batched_generation(in_queue, out_queue, timeout):
|
||||
generated_ids = np.concatenate([generated_ids_1, generated_ids_2])
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
[50258, 50259, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404],
|
||||
[50258, 50259, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257],
|
||||
[50258, 50259, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904],
|
||||
[50258, 50259, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439]
|
||||
EXPECTED_IDS = [
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
unittest.TestCase().assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
unittest.TestCase().assertEqual(generated_ids.tolist(), EXPECTED_IDS)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad",
|
||||
" Mr. Quilter is the apostle of the middle classes and we are glad to",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user