[Whisper] handle deprecation of forced_decoder_ids (#38232)

* fix

* working saved forced_decoder_ids

* docstring

* add deprecation message

* exception message ordering

* circular import comment
This commit is contained in:
Joao Gante
2025-05-22 10:16:38 +01:00
committed by GitHub
parent aa02a5d902
commit f8630c778c
5 changed files with 73 additions and 78 deletions

View File

@@ -95,7 +95,7 @@ transcription[0]
## Notes
- Whisper relies on [`~GenerationMixin.generate`] for inference.
- Whisper relies a custom [`generate`] for inference, make sure to check the docs below.
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
## WhisperConfig

View File

@@ -280,10 +280,6 @@ class GenerationConfig(PushToHubMixin):
begin_suppress_tokens (`List[int]`, *optional*):
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*):
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
of index 123.
sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
sequence being selected, while negative biases do the opposite. Check
@@ -388,12 +384,6 @@ class GenerationConfig(PushToHubMixin):
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
specific criteria are met, including using a compilable cache. Please open an issue if you find the
need to use this flag.
> Wild card
generation_kwargs:
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
present in `generate`'s signature will be used in the model forward pass.
"""
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
@@ -449,7 +439,6 @@ class GenerationConfig(PushToHubMixin):
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
self.sequence_bias = kwargs.pop("sequence_bias", None)
self.token_healing = kwargs.pop("token_healing", False)
self.guidance_scale = kwargs.pop("guidance_scale", None)
@@ -494,8 +483,6 @@ class GenerationConfig(PushToHubMixin):
# Performance
self.compile_config = kwargs.pop("compile_config", None)
self.disable_compile = kwargs.pop("disable_compile", False)
# Wild card
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
# interface.

View File

@@ -15,7 +15,7 @@
import inspect
import math
from typing import Callable, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -25,6 +25,10 @@ from ..utils import add_start_docstrings
from ..utils.logging import get_logger
# TODO (joao): We shouldn't need this, but there would be a circular import
if TYPE_CHECKING:
from ..generation.configuration_utils import GenerationConfig
logger = get_logger(__name__)
@@ -1906,8 +1910,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
begin_index (`int`):
Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*):
Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
``` python
@@ -1940,8 +1946,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
def __init__(
self,
generate_config,
begin_index: Optional[int] = None,
generate_config: "GenerationConfig",
begin_index: int,
_detect_timestamp_from_logprob: Optional[bool] = None,
): # support for the kwargs
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
@@ -1954,11 +1960,13 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
num_forced_ids = (
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
self.begin_index = begin_index
if begin_index is None:
raise ValueError(
"`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
"must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
"was `len(generate_config.forced_decoder_ids)`"
)
self.begin_index = begin_index or (num_forced_ids + 1)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50

View File

@@ -1246,12 +1246,6 @@ class GenerationMixin:
device=device,
)
)
if generation_config.forced_decoder_ids is not None:
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
raise ValueError(
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
"in favour of `input_ids` or `decoder_input_ids` respectively.",
)
# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)

View File

@@ -410,8 +410,7 @@ class WhisperGenerationMixin(GenerationMixin):
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
Task to use for generation, either "translate" or "transcribe".
language (`str` or list of `str`, *optional*):
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
batched generation, a list of language tokens can be passed. You can find all the possible language
@@ -1305,8 +1304,9 @@ class WhisperGenerationMixin(GenerationMixin):
if not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
"You have passed more than 3000 mel input features (> 30 seconds) which automatically "
"enables long-form generation which requires the model to predict timestamp tokens. Please "
"either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
@@ -1315,8 +1315,9 @@ class WhisperGenerationMixin(GenerationMixin):
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"Make sure to initialize the generation config with the correct attributes that are needed such as "
"`no_timestamps_token_id`. For more details on how to generate the approtiate config, refer to "
"https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
@@ -1324,8 +1325,9 @@ class WhisperGenerationMixin(GenerationMixin):
if hasattr(generation_config, "no_timestamps_token_id"):
timestamp_begin = generation_config.no_timestamps_token_id + 1
else:
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
# We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form
# with no timestamps. We set the timestamp begin token larger than the vocab size, such that the
# timestamp condition is never met in the decoding loop
timestamp_begin = self.config.vocab_size + 1
return timestamp_begin
@@ -1352,8 +1354,8 @@ class WhisperGenerationMixin(GenerationMixin):
if not hasattr(generation_config, "lang_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `language` argument "
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.language = language
@@ -1361,8 +1363,8 @@ class WhisperGenerationMixin(GenerationMixin):
if not hasattr(generation_config, "task_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `task` argument "
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.task = task
@@ -1392,37 +1394,40 @@ class WhisperGenerationMixin(GenerationMixin):
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
"(You should just add it to the generation config)"
f"{language_token} is not supported by this specific model as it is not in the "
"`generation_config.lang_to_id`. (You should just add it to the generation config)"
)
return generation_config.lang_to_id[language_token]
task = getattr(generation_config, "task", None)
language = getattr(generation_config, "language", None)
init_tokens = [generation_config.decoder_start_token_id]
forced_decoder_ids = generation_config.forced_decoder_ids
if forced_decoder_ids is not None:
if language is None and task is None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
)
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
# TL;DR we silently ignore `forced_decoder_ids` (old flag) when `task` or `language` (new flags) are set.
# `forced_decoder_ids` is an old generation config attribute that is now deprecated in favor of `task` and
# `language` (see https://github.com/huggingface/transformers/pull/28687). Nevertheless, keep in mind that
# the original checkpoints all contain this attribute, and thus we should maintain backwards compatibility.
if task is None and language is None:
forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
# fallback: check the model config for forced_decoder_ids
if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None:
forced_decoder_ids = config.forced_decoder_ids
if forced_decoder_ids is not None and task is not None:
if forced_decoder_ids is not None:
logger.warning_once(
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
"Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of "
"the `task` and `language` flags/config options."
)
if forced_decoder_ids is not None and forced_decoder_ids[0][1] is None:
logger.warning_once(
"Transcription using a multilingual Whisper will default to language detection followed by "
"transcription instead of translation to English. This might be a breaking change for your "
"use case. If you want to instead always translate your audio to English, make sure to pass "
"`language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details."
)
forced_decoder_ids = None
elif forced_decoder_ids is not None and language is not None:
logger.warning_once(
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
)
forced_decoder_ids = None
init_tokens = [generation_config.decoder_start_token_id]
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
i = 1
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
@@ -1432,19 +1437,20 @@ class WhisperGenerationMixin(GenerationMixin):
if len(forced_decoder_ids) > 0:
raise ValueError(
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow "
f"the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all "
f"indices >= 1 and < {forced_decoder_ids[0][0]}.",
)
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
generation_config.forced_decoder_ids = None
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
# Make sure language is a list of strings of the correct length
if isinstance(language, (list, tuple)):
if any(l is None for l in language):
raise TypeError(
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with "
"length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list "
"containing `None`."
)
if len(language) != batch_size:
raise ValueError(