[Whisper] handle deprecation of forced_decoder_ids (#38232)
* fix
* working saved forced_decoder_ids
* docstring
* add deprecation message
* exception message ordering
* circular import comment
This commit is contained in:
@@ -95,7 +95,7 @@ transcription[0]
|
|||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- Whisper relies on [`~GenerationMixin.generate`] for inference.
|
- Whisper relies on a custom [`generate`] for inference; make sure to check the docs below.
|
||||||
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
|
- The [`WhisperProcessor`] can be used for preparing audio and decoding predicted ids back into text.
|
||||||
|
|
||||||
## WhisperConfig
|
## WhisperConfig
|
||||||
|
|||||||
@@ -280,10 +280,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
begin_suppress_tokens (`List[int]`, *optional*):
|
begin_suppress_tokens (`List[int]`, *optional*):
|
||||||
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
|
A list of tokens that will be suppressed at the beginning of the generation. The `SupressBeginTokens` logit
|
||||||
processor will set their log probs to `-inf` so that they are not sampled.
|
processor will set their log probs to `-inf` so that they are not sampled.
|
||||||
forced_decoder_ids (`List[List[int]]`, *optional*):
|
|
||||||
A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
|
|
||||||
forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
|
|
||||||
of index 123.
|
|
||||||
sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
|
sequence_bias (`Dict[Tuple[int], float]`, *optional*)):
|
||||||
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the
|
||||||
sequence being selected, while negative biases do the opposite. Check
|
sequence being selected, while negative biases do the opposite. Check
|
||||||
@@ -388,12 +384,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
|
Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when
|
||||||
specific criteria are met, including using a compilable cache. Please open an issue if you find the
|
specific criteria are met, including using a compilable cache. Please open an issue if you find the
|
||||||
need to use this flag.
|
need to use this flag.
|
||||||
|
|
||||||
> Wild card
|
|
||||||
|
|
||||||
generation_kwargs:
|
|
||||||
Additional generation kwargs will be forwarded to the `generate` function of the model. Kwargs that are not
|
|
||||||
present in `generate`'s signature will be used in the model forward pass.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
|
extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits")
|
||||||
@@ -449,7 +439,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
|
||||||
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
|
self.suppress_tokens = kwargs.pop("suppress_tokens", None)
|
||||||
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
|
||||||
self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
|
|
||||||
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
self.sequence_bias = kwargs.pop("sequence_bias", None)
|
||||||
self.token_healing = kwargs.pop("token_healing", False)
|
self.token_healing = kwargs.pop("token_healing", False)
|
||||||
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
self.guidance_scale = kwargs.pop("guidance_scale", None)
|
||||||
@@ -494,8 +483,6 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
# Performance
|
# Performance
|
||||||
self.compile_config = kwargs.pop("compile_config", None)
|
self.compile_config = kwargs.pop("compile_config", None)
|
||||||
self.disable_compile = kwargs.pop("disable_compile", False)
|
self.disable_compile = kwargs.pop("disable_compile", False)
|
||||||
# Wild card
|
|
||||||
self.generation_kwargs = kwargs.pop("generation_kwargs", {})
|
|
||||||
|
|
||||||
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
# The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
|
||||||
# interface.
|
# interface.
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from typing import Callable, Iterable, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -25,6 +25,10 @@ from ..utils import add_start_docstrings
|
|||||||
from ..utils.logging import get_logger
|
from ..utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
# TODO (joao): We shouldn't need this, but there would be a circular import
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..generation.configuration_utils import GenerationConfig
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -1906,8 +1910,10 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
||||||
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
||||||
predicting timestamps that are too far in the future.
|
predicting timestamps that are too far in the future.
|
||||||
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
|
begin_index (`int`):
|
||||||
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
|
Token index of the first token that is generated by the model.
|
||||||
|
_detect_timestamp_from_logprob (`bool`, *optional*):
|
||||||
|
Whether timestamps can be predicted from logprobs over all timestamps.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
``` python
|
``` python
|
||||||
@@ -1940,8 +1946,8 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
generate_config,
|
generate_config: "GenerationConfig",
|
||||||
begin_index: Optional[int] = None,
|
begin_index: int,
|
||||||
_detect_timestamp_from_logprob: Optional[bool] = None,
|
_detect_timestamp_from_logprob: Optional[bool] = None,
|
||||||
): # support for the kwargs
|
): # support for the kwargs
|
||||||
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
||||||
@@ -1954,11 +1960,13 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
if _detect_timestamp_from_logprob is not None
|
if _detect_timestamp_from_logprob is not None
|
||||||
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
|
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
|
||||||
)
|
)
|
||||||
|
self.begin_index = begin_index
|
||||||
num_forced_ids = (
|
if begin_index is None:
|
||||||
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
|
raise ValueError(
|
||||||
)
|
"`forced_decoder_ids` is deprecated in favor of `task` and `language` and, as such, `begin_index` "
|
||||||
self.begin_index = begin_index or (num_forced_ids + 1)
|
"must be provided to `WhisperTimeStampLogitsProcessor`. The previous default value of `begin_index` "
|
||||||
|
"was `len(generate_config.forced_decoder_ids)`"
|
||||||
|
)
|
||||||
|
|
||||||
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
|
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
|
||||||
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
|
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
|
||||||
|
|||||||
@@ -1246,12 +1246,6 @@ class GenerationMixin:
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if generation_config.forced_decoder_ids is not None:
|
|
||||||
# TODO (sanchit): move this exception to GenerationConfig.validate() when TF & FLAX are aligned with PT
|
|
||||||
raise ValueError(
|
|
||||||
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
|
|
||||||
"in favour of `input_ids` or `decoder_input_ids` respectively.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO (joao): find a strategy to specify the order of the processors
|
# TODO (joao): find a strategy to specify the order of the processors
|
||||||
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
processors = self._merge_criteria_processor_list(processors, logits_processor)
|
||||||
|
|||||||
@@ -410,8 +410,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
return_timestamps (`bool`, *optional*):
|
return_timestamps (`bool`, *optional*):
|
||||||
Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
|
Whether to return the timestamps with the text. This enables the `WhisperTimeStampLogitsProcessor`.
|
||||||
task (`str`, *optional*):
|
task (`str`, *optional*):
|
||||||
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
|
Task to use for generation, either "translate" or "transcribe".
|
||||||
will be updated accordingly.
|
|
||||||
language (`str` or list of `str`, *optional*):
|
language (`str` or list of `str`, *optional*):
|
||||||
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
|
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. For
|
||||||
batched generation, a list of language tokens can be passed. You can find all the possible language
|
batched generation, a list of language tokens can be passed. You can find all the possible language
|
||||||
@@ -1305,8 +1304,9 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if not is_shortform:
|
if not is_shortform:
|
||||||
if return_timestamps is False:
|
if return_timestamps is False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
|
"You have passed more than 3000 mel input features (> 30 seconds) which automatically "
|
||||||
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
|
"enables long-form generation which requires the model to predict timestamp tokens. Please "
|
||||||
|
"either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Setting `return_timestamps=True` for long-form generation.")
|
logger.info("Setting `return_timestamps=True` for long-form generation.")
|
||||||
@@ -1315,8 +1315,9 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
|
if return_timestamps and not hasattr(generation_config, "no_timestamps_token_id"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You are trying to return timestamps, but the generation config is not properly set. "
|
"You are trying to return timestamps, but the generation config is not properly set. "
|
||||||
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
|
"Make sure to initialize the generation config with the correct attributes that are needed such as "
|
||||||
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
"`no_timestamps_token_id`. For more details on how to generate the approtiate config, refer to "
|
||||||
|
"https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
||||||
)
|
)
|
||||||
|
|
||||||
generation_config.return_timestamps = return_timestamps
|
generation_config.return_timestamps = return_timestamps
|
||||||
@@ -1324,8 +1325,9 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if hasattr(generation_config, "no_timestamps_token_id"):
|
if hasattr(generation_config, "no_timestamps_token_id"):
|
||||||
timestamp_begin = generation_config.no_timestamps_token_id + 1
|
timestamp_begin = generation_config.no_timestamps_token_id + 1
|
||||||
else:
|
else:
|
||||||
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form with no timestamps
|
# BC for models missing the `no_timestamps_token_id` in the generation config when generating short-form
|
||||||
# We set the timestamp begin token larger than the vocab size, such that the timestamp condition is never met in the decoding loop
|
# with no timestamps. We set the timestamp begin token larger than the vocab size, such that the
|
||||||
|
# timestamp condition is never met in the decoding loop
|
||||||
timestamp_begin = self.config.vocab_size + 1
|
timestamp_begin = self.config.vocab_size + 1
|
||||||
|
|
||||||
return timestamp_begin
|
return timestamp_begin
|
||||||
@@ -1352,8 +1354,8 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if not hasattr(generation_config, "lang_to_id"):
|
if not hasattr(generation_config, "lang_to_id"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The generation config is outdated and is thus not compatible with the `language` argument "
|
"The generation config is outdated and is thus not compatible with the `language` argument "
|
||||||
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
|
"to `generate`. Please update the generation config as per the instructions "
|
||||||
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
||||||
)
|
)
|
||||||
generation_config.language = language
|
generation_config.language = language
|
||||||
|
|
||||||
@@ -1361,8 +1363,8 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if not hasattr(generation_config, "task_to_id"):
|
if not hasattr(generation_config, "task_to_id"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The generation config is outdated and is thus not compatible with the `task` argument "
|
"The generation config is outdated and is thus not compatible with the `task` argument "
|
||||||
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
|
"to `generate`. Please update the generation config as per the instructions "
|
||||||
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
||||||
)
|
)
|
||||||
generation_config.task = task
|
generation_config.task = task
|
||||||
|
|
||||||
@@ -1392,51 +1394,53 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
)
|
)
|
||||||
if language_token not in generation_config.lang_to_id:
|
if language_token not in generation_config.lang_to_id:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
|
f"{language_token} is not supported by this specific model as it is not in the "
|
||||||
"(You should just add it to the generation config)"
|
"`generation_config.lang_to_id`. (You should just add it to the generation config)"
|
||||||
)
|
)
|
||||||
|
|
||||||
return generation_config.lang_to_id[language_token]
|
return generation_config.lang_to_id[language_token]
|
||||||
|
|
||||||
task = getattr(generation_config, "task", None)
|
task = getattr(generation_config, "task", None)
|
||||||
language = getattr(generation_config, "language", None)
|
language = getattr(generation_config, "language", None)
|
||||||
|
|
||||||
forced_decoder_ids = generation_config.forced_decoder_ids
|
|
||||||
if forced_decoder_ids is not None:
|
|
||||||
if language is None and task is None and forced_decoder_ids[0][1] is None:
|
|
||||||
logger.warning_once(
|
|
||||||
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
|
|
||||||
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
|
|
||||||
)
|
|
||||||
elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
|
|
||||||
forced_decoder_ids = config.forced_decoder_ids
|
|
||||||
|
|
||||||
if forced_decoder_ids is not None and task is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
f"You have passed task={task}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of task={task}."
|
|
||||||
)
|
|
||||||
forced_decoder_ids = None
|
|
||||||
elif forced_decoder_ids is not None and language is not None:
|
|
||||||
logger.warning_once(
|
|
||||||
f"You have passed language={language}, but also have set `forced_decoder_ids` to {forced_decoder_ids} which creates a conflict. `forced_decoder_ids` will be ignored in favor of language={language}."
|
|
||||||
)
|
|
||||||
forced_decoder_ids = None
|
|
||||||
|
|
||||||
init_tokens = [generation_config.decoder_start_token_id]
|
init_tokens = [generation_config.decoder_start_token_id]
|
||||||
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
|
|
||||||
i = 1
|
|
||||||
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
|
|
||||||
init_tokens += [forced_decoder_ids[0][1]]
|
|
||||||
forced_decoder_ids = forced_decoder_ids[1:]
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
if len(forced_decoder_ids) > 0:
|
# TL;DR we silently ignore `forced_decoder_ids` (old flag) when `task` or `language` (new flags) are set.
|
||||||
raise ValueError(
|
# `forced_decoder_ids` is an old generation config attribute that is now deprecated in favor of `task` and
|
||||||
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all indices >= 1 and < {forced_decoder_ids[0][0]}.",
|
# `language` (see https://github.com/huggingface/transformers/pull/28687). Nevertheless, keep in mind that
|
||||||
|
# the original checkpoints all contain this attribute, and thus we should maintain backwards compatibility.
|
||||||
|
if task is None and language is None:
|
||||||
|
forced_decoder_ids = getattr(generation_config, "forced_decoder_ids", None)
|
||||||
|
# fallback: check the model config for forced_decoder_ids
|
||||||
|
if forced_decoder_ids is None and getattr(config, "forced_decoder_ids", None) is not None:
|
||||||
|
forced_decoder_ids = config.forced_decoder_ids
|
||||||
|
|
||||||
|
if forced_decoder_ids is not None:
|
||||||
|
logger.warning_once(
|
||||||
|
"Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of "
|
||||||
|
"the `task` and `language` flags/config options."
|
||||||
)
|
)
|
||||||
|
|
||||||
# from v4.39 the forced decoder ids are always None in favour of decoder input ids
|
if forced_decoder_ids is not None and forced_decoder_ids[0][1] is None:
|
||||||
generation_config.forced_decoder_ids = None
|
logger.warning_once(
|
||||||
|
"Transcription using a multilingual Whisper will default to language detection followed by "
|
||||||
|
"transcription instead of translation to English. This might be a breaking change for your "
|
||||||
|
"use case. If you want to instead always translate your audio to English, make sure to pass "
|
||||||
|
"`language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details."
|
||||||
|
)
|
||||||
|
|
||||||
|
if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1:
|
||||||
|
i = 1
|
||||||
|
while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i:
|
||||||
|
init_tokens += [forced_decoder_ids[0][1]]
|
||||||
|
forced_decoder_ids = forced_decoder_ids[1:]
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
if len(forced_decoder_ids) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"You are using token ids in `forced_decoder_ids` that do not seem to correctly follow "
|
||||||
|
f"the prompt pattern of Whisper. Make sure that {forced_decoder_ids} has an entry for all "
|
||||||
|
f"indices >= 1 and < {forced_decoder_ids[0][0]}.",
|
||||||
|
)
|
||||||
|
|
||||||
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
|
is_lang_id_undefined = len(init_tokens) <= 1 or (len(init_tokens) > 1 and init_tokens[1] is None)
|
||||||
|
|
||||||
@@ -1444,7 +1448,9 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if isinstance(language, (list, tuple)):
|
if isinstance(language, (list, tuple)):
|
||||||
if any(l is None for l in language):
|
if any(l is None for l in language):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list containing `None`."
|
"Expected `language` to be `None`, a single string (e.g. `'en'`), or a list of strings with "
|
||||||
|
"length equal to the batch size (e.g. `('en', 'fr')` for a batch size of 2). Got a list "
|
||||||
|
"containing `None`."
|
||||||
)
|
)
|
||||||
if len(language) != batch_size:
|
if len(language) != batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Reference in New Issue
Block a user