From f8630c778c9220defecf1e3026d3438108b0baba Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 22 May 2025 10:16:38 +0100 Subject: [PATCH] [Whisper] handle deprecation of `forced_decoder_ids` (#38232) * fix * working saved forced_decoder_ids * docstring * add deprecation message * exception message ordering * circular import comment --- docs/source/en/model_doc/whisper.md | 2 +- .../generation/configuration_utils.py | 13 --- src/transformers/generation/logits_process.py | 28 +++-- src/transformers/generation/utils.py | 6 -- .../models/whisper/generation_whisper.py | 102 +++++++++--------- 5 files changed, 73 insertions(+), 78 deletions(-) diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index 3aa6e5c301..4bb51d0ce8 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -95,7 +95,7 @@ transcription[0] ## Notes -- Whisper relies on [`~GenerationMixin.generate`] for inference. +- Whisper relies on a custom [`generate`] for inference; make sure to check the docs below. - The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text. ## WhisperConfig diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 353dae2b6f..9bfa5a64d7 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -280,10 +280,6 @@ class GenerationConfig(PushToHubMixin): begin_suppress_tokens (`List[int]`, *optional*): A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled. - forced_decoder_ids (`List[List[int]]`, *optional*): - A list of pairs of integers which indicates a mapping from generation indices to token indices that will be - forced before sampling.
For example, `[[1, 123]]` means the second generated token will always be a token - of index 123. sequence_bias (`Dict[Tuple[int], float]`, *optional*)): Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the sequence being selected, while negative biases do the opposite. Check @@ -388,12 +384,6 @@ class GenerationConfig(PushToHubMixin): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compilable cache. Please open an issue if you find the need to use this flag. - - > Wild card - - generation_kwargs: - Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not - present in `generate`'s signature will be used in the model forward pass. """ extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits") @@ -449,7 +439,6 @@ class GenerationConfig(PushToHubMixin): self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None) self.suppress_tokens = kwargs.pop("suppress_tokens", None) self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) - self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) self.sequence_bias = kwargs.pop("sequence_bias", None) self.token_healing = kwargs.pop("token_healing", False) self.guidance_scale = kwargs.pop("guidance_scale", None) @@ -494,8 +483,6 @@ class GenerationConfig(PushToHubMixin): # Performance self.compile_config = kwargs.pop("compile_config", None) self.disable_compile = kwargs.pop("disable_compile", False) - # Wild card - self.generation_kwargs = kwargs.pop("generation_kwargs", {}) # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub # interface. 
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 34c7ea532c..6e0f0154ab 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -15,7 +15,7 @@ import inspect import math -from typing import Callable, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -25,6 +25,10 @@ from ..utils import add_start_docstrings from ..utils.logging import get_logger +# TODO (joao): We shouldn't need this, but there would be a circular import +if TYPE_CHECKING: + from ..generation.configuration_utils import GenerationConfig + logger = get_logger(__name__) @@ -1906,8 +1910,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): max_initial_timestamp_index (`int`, *optional*, defaults to 1): Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting timestamps that are too far in the future. - begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model. - _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps. + begin_index (`int`): + Token index of the first token that is generated by the model. + _detect_timestamp_from_logprob (`bool`, *optional*): + Whether timestamps can be predicted from logprobs over all timestamps. 
Examples: ``` python @@ -1940,8 +1946,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): def __init__( self, - generate_config, - begin_index: Optional[int] = None, + generate_config: "GenerationConfig", + begin_index: int, _detect_timestamp_from_logprob: Optional[bool] = None, ): # support for the kwargs self.no_timestamps_token_id = generate_config.no_timestamps_token_id @@ -1954,11 +1960,13 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): if _detect_timestamp_from_logprob is not None else getattr(generate_config, "_detect_timestamp_from_logprob", True) ) - - num_forced_ids = ( - len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 - ) - self.begin_index = begin_index or (num_forced_ids + 1) + self.begin_index = begin_index + if begin_index is None: + raise ValueError( + "`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` " + "must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` " + "was `len(generate_config.forced_decoder_ids)`" + ) self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 510f55d824..60859662f6 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1246,12 +1246,6 @@ class GenerationMixin: device=device, ) ) - if generation_config.forced_decoder_ids is not None: - # TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT - raise ValueError( - "You have explicitly specified `forced_decoder_ids`. 
Please remove the `forced_decoder_ids` argument " - "in favour of `input_ids` or `decoder_input_ids` respectively.", - ) # TODO (joao): find a strategy to specify the order of the processors processors = self._merge_criteria_processor_list(processors, logits_processor) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index da1d83b2a8..4c29d456bf 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -410,8 +410,7 @@ class WhisperGenerationMixin(GenerationMixin): return_timestamps (`bool`, *optional*): Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. task (`str`, *optional*): - Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` - will be updated accordingly. + Task to use for generation, either "translate" or "transcribe". language (`str` or list of `str`, *optional*): Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For batched generation, a list of language tokens can be passed. You can find all the possible language @@ -1305,8 +1304,9 @@ class WhisperGenerationMixin(GenerationMixin): if not is_shortform: if return_timestamps is False: raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." + "You have passed more than 3000 mel input features (> 30 seconds) which automatically " + "enables long-form generation which requires the model to predict timestamp tokens. Please " + "either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." 
) logger.info("Setting `return_timestamps=True` for long-form generation.") @@ -1315,8 +1315,9 @@ class WhisperGenerationMixin(GenerationMixin): if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"): raise ValueError( "You are trying to return timestamps, but the generation config is not properly set. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + "Make sure to initialize the generation config with the correct attributes that are needed such as " + "`no_timestamps_token_id`. For more details on how to generate the approtiate config, refer to " + "https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" ) generation_config.return_timestamps = return_timestamps @@ -1324,8 +1325,9 @@ class WhisperGenerationMixin(GenerationMixin): if hasattr(generation_config, "no_timestamps_token_id"): timestamp_begin = generation_config.no_timestamps_token_id + 1 else: - # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps - # We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop + # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form + # with no timestamps. We set the timestamp begin token larger than the vocab size, such that the + # timestamp condition is never met in the decoding loop timestamp_begin = self.config.vocab_size + 1 return timestamp_begin @@ -1352,8 +1354,8 @@ class WhisperGenerationMixin(GenerationMixin): if not hasattr(generation_config, "lang_to_id"): raise ValueError( "The generation config is outdated and is thus not compatible with the `language` argument " - "to `generate`. 
Either set the language using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + "to `generate`. Please update the generation config as per the instructions " + "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" ) generation_config.language = language @@ -1361,8 +1363,8 @@ class WhisperGenerationMixin(GenerationMixin): if not hasattr(generation_config, "task_to_id"): raise ValueError( "The generation config is outdated and is thus not compatible with the `task` argument " - "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + "to `generate`. Please update the generation config as per the instructions " + "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" ) generation_config.task = task @@ -1392,51 +1394,53 @@ class WhisperGenerationMixin(GenerationMixin): ) if language_token not in generation_config.lang_to_id: raise ValueError( - f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." - "(You should just add it to the generation config)" + f"{language_token} is not supported by this specific model as it is not in the " + "`generation_config.lang_to_id`. 
(You should just add it to the generation config)" ) return generation_config.lang_to_id[language_token] task = getattr(generation_config, "task", None) language = getattr(generation_config, "language", None) - - forced_decoder_ids = generation_config.forced_decoder_ids - if forced_decoder_ids is not None: - if language is None and task is None and forced_decoder_ids[0][1] is None: - logger.warning_once( - "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English." - "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`." - ) - elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: - forced_decoder_ids = config.forced_decoder_ids - - if forced_decoder_ids is not None and task is not None: - logger.warning_once( - f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}." - ) - forced_decoder_ids = None - elif forced_decoder_ids is not None and language is not None: - logger.warning_once( - f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}." 
- ) - forced_decoder_ids = None - init_tokens = [generation_config.decoder_start_token_id] - if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: - i = 1 - while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: - init_tokens += [forced_decoder_ids[0][1]] - forced_decoder_ids = forced_decoder_ids[1:] - i += 1 - if len(forced_decoder_ids) > 0: - raise ValueError( - f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.", + # TL;DR we silently ignore `forced_decoder_ids` (old flag) when `task` or `language` (new flags) are set. + # `forced_decoder_ids` is an old generation config attribute that is now deprecated in favor of `task` and + # `language` (see https://github.com/huggingface/transformers/pull/28687). Nevertheless, keep in mind that + # the original checkpoints all contain this attribute, and thus we should maintain backwards compatibility. + if task is None and language is None: + forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None) + # fallback: check the model config for forced_decoder_ids + if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None: + forced_decoder_ids = config.forced_decoder_ids + + if forced_decoder_ids is not None: + logger.warning_once( + "Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of " + "the `task` and `language` flags/config options." ) - # from v4.39 the forced decoder ids are always None in favour of decoder input ids - generation_config.forced_decoder_ids = None + if forced_decoder_ids is not None and forced_decoder_ids[0][1] is None: + logger.warning_once( + "Transcription using a multilingual Whisper will default to language detection followed by " + "transcription instead of translation to English. 
This might be a breaking change for your " + "use case. If you want to instead always translate your audio to English, make sure to pass " + "`language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details." + ) + + if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: + i = 1 + while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[0][1]] + forced_decoder_ids = forced_decoder_ids[1:] + i += 1 + + if len(forced_decoder_ids) > 0: + raise ValueError( + f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow " + f"the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all " + f"indices >= 1 and < {forced_decoder_ids[0][0]}.", + ) is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None) @@ -1444,7 +1448,9 @@ class WhisperGenerationMixin(GenerationMixin): if isinstance(language, (list, tuple)): if any(l is None for l in language): raise TypeError( - "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`." + "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with " + "length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list " + "containing `None`." ) if len(language) != batch_size: raise ValueError(