[Whisper] handle deprecation of forced_decoder_ids (#38232)

* fix

* working saved forced_decoder_ids

* docstring

* add deprecation message

* exception message ordering

* circular import comment
This commit is contained in:
Joao Gante
2025-05-22 10:16:38 +01:00
committed by GitHub
parent aa02a5d902
commit f8630c778c
5 changed files with 73 additions and 78 deletions

View File

@@ -95,7 +95,7 @@ transcription[0]
## Notes ## Notes
- Whisper relies on [`~GenerationMixin.generate`] for inference. - Whisper relies a custom [`generate`] for inference, make sure to check the docs below.
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text. - The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
## WhisperConfig ## WhisperConfig

View File

@@ -280,10 +280,6 @@ class GenerationConfig(PushToHubMixin):
begin_suppress_tokens (`List[int]`, *optional*): begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled. processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
of index 123.
sequence_bias (`Dict[Tuple[int], float]`, *optional*)): sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. Check sequence being selected, while negative biases do the opposite. Check
@@ -388,12 +384,6 @@ class GenerationConfig(PushToHubMixin):
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
specific criteria are met, including using a compilable cache. Please open an issue if you find the specific criteria are met, including using a compilable cache. Please open an issue if you find the
need to use this flag. need to use this flag.
> Wild card
generation_kwargs:
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
present in `generate`'s signature will be used in the model forward pass.
""" """
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits") extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
@@ -449,7 +439,6 @@ class GenerationConfig(PushToHubMixin):
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None) self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
self.suppress_tokens = kwargs.pop("suppress_tokens", None) self.suppress_tokens = kwargs.pop("suppress_tokens", None)
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None) self.sequence_bias = kwargs.pop("sequence_bias", None)
self.token_healing = kwargs.pop("token_healing", False) self.token_healing = kwargs.pop("token_healing", False)
self.guidance_scale = kwargs.pop("guidance_scale", None) self.guidance_scale = kwargs.pop("guidance_scale", None)
@@ -494,8 +483,6 @@ class GenerationConfig(PushToHubMixin):
# Performance # Performance
self.compile_config = kwargs.pop("compile_config", None) self.compile_config = kwargs.pop("compile_config", None)
self.disable_compile = kwargs.pop("disable_compile", False) self.disable_compile = kwargs.pop("disable_compile", False)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
# interface. # interface.

View File

@@ -15,7 +15,7 @@
import inspect import inspect
import math import math
from typing import Callable, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@@ -25,6 +25,10 @@ from ..utils import add_start_docstrings
from ..utils.logging import get_logger from ..utils.logging import get_logger
# TODO (joao): We shouldn't need this, but there would be a circular import
if TYPE_CHECKING:
from ..generation.configuration_utils import GenerationConfig
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -1906,8 +1910,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1): max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future. predicting timestamps that are too far in the future.
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model. begin_index (`int`):
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps. Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*):
Whether timestamps can be predicted from logprobs over all timestamps.
Examples: Examples:
``` python ``` python
@@ -1940,8 +1946,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
def __init__( def __init__(
self, self,
generate_config, generate_config: "GenerationConfig",
begin_index: Optional[int] = None, begin_index: int,
_detect_timestamp_from_logprob: Optional[bool] = None, _detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs ): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.no_timestamps_token_id = generate_config.no_timestamps_token_id
@@ -1954,11 +1960,13 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
if _detect_timestamp_from_logprob is not None if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True) else getattr(generate_config, "_detect_timestamp_from_logprob", True)
) )
self.begin_index = begin_index
num_forced_ids = ( if begin_index is None:
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 raise ValueError(
"`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
"must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
"was `len(generate_config.forced_decoder_ids)`"
) )
self.begin_index = begin_index or (num_forced_ids + 1)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50

View File

@@ -1246,12 +1246,6 @@ class GenerationMixin:
device=device, device=device,
) )
) )
if generation_config.forced_decoder_ids is not None:
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
raise ValueError(
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
"in favour of `input_ids` or `decoder_input_ids` respectively.",
)
# TODO (joao): find a strategy to specify the order of the processors # TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor) processors = self._merge_criteria_processor_list(processors, logits_processor)

View File

@@ -410,8 +410,7 @@ class WhisperGenerationMixin(GenerationMixin):
return_timestamps (`bool`, *optional*): return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
task (`str`, *optional*): task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` Task to use for generation, either "translate" or "transcribe".
will be updated accordingly.
language (`str` or list of `str`, *optional*): language (`str` or list of `str`, *optional*):
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
batched generation, a list of language tokens can be passed. You can find all the possible language batched generation, a list of language tokens can be passed. You can find all the possible language
@@ -1305,8 +1304,9 @@ class WhisperGenerationMixin(GenerationMixin):
if not is_shortform: if not is_shortform:
if return_timestamps is False: if return_timestamps is False:
raise ValueError( raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " "You have passed more than 3000 mel input features (> 30 seconds) which automatically "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." "enables long-form generation which requires the model to predict timestamp tokens. Please "
"either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
) )
logger.info("Setting `return_timestamps=True` for long-form generation.") logger.info("Setting `return_timestamps=True` for long-form generation.")
@@ -1315,8 +1315,9 @@ class WhisperGenerationMixin(GenerationMixin):
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"): if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError( raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. " "You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " "Make sure to initialize the generation config with the correct attributes that are needed such as "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" "`no_timestamps_token_id`. For more details on how to generate the approtiate config, refer to "
"https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
) )
generation_config.return_timestamps = return_timestamps generation_config.return_timestamps = return_timestamps
@@ -1324,8 +1325,9 @@ class WhisperGenerationMixin(GenerationMixin):
if hasattr(generation_config, "no_timestamps_token_id"): if hasattr(generation_config, "no_timestamps_token_id"):
timestamp_begin = generation_config.no_timestamps_token_id + 1 timestamp_begin = generation_config.no_timestamps_token_id + 1
else: else:
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps # BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form
# We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop # with no timestamps. We set the timestamp begin token larger than the vocab size, such that the
# timestamp condition is never met in the decoding loop
timestamp_begin = self.config.vocab_size + 1 timestamp_begin = self.config.vocab_size + 1
return timestamp_begin return timestamp_begin
@@ -1352,8 +1354,8 @@ class WhisperGenerationMixin(GenerationMixin):
if not hasattr(generation_config, "lang_to_id"): if not hasattr(generation_config, "lang_to_id"):
raise ValueError( raise ValueError(
"The generation config is outdated and is thus not compatible with the `language` argument " "The generation config is outdated and is thus not compatible with the `language` argument "
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " "to `generate`. Please update the generation config as per the instructions "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
) )
generation_config.language = language generation_config.language = language
@@ -1361,8 +1363,8 @@ class WhisperGenerationMixin(GenerationMixin):
if not hasattr(generation_config, "task_to_id"): if not hasattr(generation_config, "task_to_id"):
raise ValueError( raise ValueError(
"The generation config is outdated and is thus not compatible with the `task` argument " "The generation config is outdated and is thus not compatible with the `task` argument "
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " "to `generate`. Please update the generation config as per the instructions "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
) )
generation_config.task = task generation_config.task = task
@@ -1392,37 +1394,40 @@ class WhisperGenerationMixin(GenerationMixin):
) )
if language_token not in generation_config.lang_to_id: if language_token not in generation_config.lang_to_id:
raise ValueError( raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." f"{language_token} is not supported by this specific model as it is not in the "
"(You should just add it to the generation config)" "`generation_config.lang_to_id`. (You should just add it to the generation config)"
) )
return generation_config.lang_to_id[language_token] return generation_config.lang_to_id[language_token]
task = getattr(generation_config, "task", None) task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None) language = getattr(generation_config, "language", None)
init_tokens = [generation_config.decoder_start_token_id]
forced_decoder_ids = generation_config.forced_decoder_ids # TL;DR we silently ignore `forced_decoder_ids` (old flag) when `task` or `language` (new flags) are set.
if forced_decoder_ids is not None: # `forced_decoder_ids` is an old generation config attribute that is now deprecated in favor of `task` and
if language is None and task is None and forced_decoder_ids[0][1] is None: # `language` (see https://github.com/huggingface/transformers/pull/28687). Nevertheless, keep in mind that
logger.warning_once( # the original checkpoints all contain this attribute, and thus we should maintain backwards compatibility.
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English." if task is None and language is None:
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`." forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
) # fallback: check the model config for forced_decoder_ids
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None:
forced_decoder_ids = config.forced_decoder_ids forced_decoder_ids = config.forced_decoder_ids
if forced_decoder_ids is not None and task is not None: if forced_decoder_ids is not None:
logger.warning_once( logger.warning_once(
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}." "Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of "
"the `task` and `language` flags/config options."
)
if forced_decoder_ids is not None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Transcription using a multilingual Whisper will default to language detection followed by "
"transcription instead of translation to English. This might be a breaking change for your "
"use case. If you want to instead always translate your audio to English, make sure to pass "
"`language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details."
) )
forced_decoder_ids = None
elif forced_decoder_ids is not None and language is not None:
logger.warning_once(
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
)
forced_decoder_ids = None
init_tokens = [generation_config.decoder_start_token_id]
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1 i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
@@ -1432,19 +1437,20 @@ class WhisperGenerationMixin(GenerationMixin):
if len(forced_decoder_ids) > 0: if len(forced_decoder_ids) > 0:
raise ValueError( raise ValueError(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.", f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow "
f"the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all "
f"indices >= 1 and < {forced_decoder_ids[0][0]}.",
) )
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
generation_config.forced_decoder_ids = None
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None) is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
# Make sure language is a list of strings of the correct length # Make sure language is a list of strings of the correct length
if isinstance(language, (list, tuple)): if isinstance(language, (list, tuple)):
if any(l is None for l in language): if any(l is None for l in language):
raise TypeError( raise TypeError(
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`." "Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with "
"length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list "
"containing `None`."
) )
if len(language) != batch_size: if len(language) != batch_size:
raise ValueError( raise ValueError(