[Whisper] Finalize batched SOTA long-form generation (#27658)

* finalize

* make fix copies whisper

* [Tests] Make sure that we don't run tests multiple times

* Update src/transformers/models/whisper/modeling_whisper.py

* [Tests] Make sure that we don't run tests multiple times

* fix more

* improve

* improve

* improve further

* improve more

* improve

* fix more

* git commit and git push

* fix more

* fix more

* fix more

* New try

* Fix more whisper stuff

* Improve

* correct more

* correct more

* correct more

* Fix some tests

* Add more tests

* correct more

* correct more

* correct more

* push

* correct more

* Fix more

* Better

* without dec mask

* correct more

* clean

* save intermediate

* Fix more

* Fix VAD for large-v2

* Save new

* Correct more

* make cleaner

* correct tests

* correct src

* Finish

* Fix more

* Fix more

* finish

* Fix edge cases

* fix return_dict_in_generate

* fix all tests

* make style

* add docstrings

* add docstrings

* Fix logit processor

* make style

* fix pipeline test

* fix more style

* Apply suggestions from code review

* apply feedback Sanchit

* correct more

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* correct more

* correct more

* correct more

* Fix staticmethod

* correct more

* fix

* fix slow tests

* make style

* fix tokenizer test

* fix tokenizer test

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* finish

* finish

* revert kwargs change

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

Author: Patrick von Platen
Date: 2024-01-19 14:04:17 +02:00
Committed by GitHub
Parent: d4fc1eb498
Commit: 690fe73f20
8 changed files with 1825 additions and 852 deletions

@@ -95,6 +95,7 @@ class LogitsProcessorList(list):
                 scores = processor(input_ids, scores, **kwargs)
             else:
                 scores = processor(input_ids, scores)
+
         return scores
@@ -1657,6 +1658,9 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
         self.begin_suppress_tokens = list(begin_suppress_tokens)
         self.begin_index = begin_index

+    def set_begin_index(self, begin_index):
+        self.begin_index = begin_index
+
     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if input_ids.shape[1] == self.begin_index:
@@ -1778,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         max_initial_timestamp_index (`int`, *optional*, defaults to 1):
             Used to set the maximum value of the initial timestamp. This is used to prevent the model from
             predicting timestamps that are too far in the future.
+        begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
         _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.

     Examples:
@@ -1810,11 +1815,11 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
     """

     def __init__(
-        self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
+        self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
     ):  # support for the kwargs
-        self.eos_token_id = generate_config.eos_token_id
         self.no_timestamps_token_id = generate_config.no_timestamps_token_id
         self.timestamp_begin = generate_config.no_timestamps_token_id + 1
+        self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

         # this variable is mostly just used for testing
         self._detect_timestamp_from_logprob = (
@@ -1823,10 +1828,17 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
             else getattr(generate_config, "_detect_timestamp_from_logprob", True)
         )

-        self.begin_index = (
-            len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
+        num_forced_ids = (
+            len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
         )
+        self.begin_index = begin_index or (num_forced_ids + 1)
+
         self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
+        # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
+        # self.max_initial_timestamp_index = 50
+
+    def set_begin_index(self, begin_index):
+        self.begin_index = begin_index

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
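
The `begin_index` fallback introduced in this hunk can be sketched in isolation. This is only an illustration under stated assumptions: `resolve_begin_index` is a hypothetical helper name (the commit sets `self.begin_index` inline), and the forced decoder ids are placeholder values. One consequence worth noting is that `or` treats an explicit `0` the same as `None`.

```python
from typing import List, Optional, Tuple


def resolve_begin_index(
    forced_decoder_ids: Optional[List[Tuple[int, int]]],
    begin_index: Optional[int] = None,
) -> int:
    """Hypothetical helper mirroring the `begin_index or (num_forced_ids + 1)` expression."""
    num_forced_ids = len(forced_decoder_ids) if forced_decoder_ids is not None else 0
    # `or` falls back to the default for both None and 0.
    return begin_index or (num_forced_ids + 1)


# Placeholder (position, token_id) pairs, not real Whisper token ids.
forced = [(1, 101), (2, 102), (3, 103)]

default_index = resolve_begin_index(forced)      # three forced ids plus one
explicit_index = resolve_begin_index(forced, 7)  # an explicit value takes precedence
no_forced_index = resolve_begin_index(None)      # no forced ids at all
zero_index = resolve_begin_index(forced, 0)      # 0 is falsy, so the default is used
```

The new `set_begin_index` method gives callers a way to change the value after construction.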
@@ -1878,6 +1890,60 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
         return scores


+class WhisperNoSpeechDetection(LogitsProcessor):
+    r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""
+
+    def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
+        self.no_speech_token = no_speech_token
+        # offset between <start-of-transcription> token, <SOT>, in paper and first generated token
+        # is equal to the position of the first generated token index
+        self.start_of_trans_offset = begin_index
+
+        # `self.begin_index` is a running value that is changed on the fly
+        self.begin_index = begin_index
+        self._no_speech_prob = [0.0]
+        self.is_scores_logprobs = scores_is_logprobs
+
+        # overwritten dynamically
+        self.model = None
+        self.inputs = None
+
+    def set_model(self, model):
+        self.model = model
+
+    def set_inputs(self, inputs):
+        self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
+        self.inputs["input_features"] = self.inputs.pop("inputs")
+
+    @property
+    def no_speech_prob(self):
+        return self._no_speech_prob
+
+    def set_begin_index(self, begin_index):
+        self.begin_index = begin_index
+
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if input_ids.shape[1] == self.begin_index:
+            if self.start_of_trans_offset > 1:
+                with torch.no_grad():
+                    logits = self.model(**self.inputs).logits
+
+                no_speech_index = self.begin_index - self.start_of_trans_offset
+                no_speech_scores = logits[:, no_speech_index]
+            else:
+                no_speech_scores = scores
+
+            if self.is_scores_logprobs:
+                probs = no_speech_scores.exp()
+            else:
+                probs = no_speech_scores.float().softmax(dim=-1)
+
+            self._no_speech_prob = probs[:, self.no_speech_token]
+
+        return scores
+
+
 class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
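
The probability step at the end of `WhisperNoSpeechDetection.__call__` can be illustrated without a model. The sketch below is a plain-Python stand-in for the tensor operations (`exp` for log-probabilities, `softmax` for raw logits); the function names, the toy scores, and the choice of token id 2 as the no-speech token are all assumptions made for illustration.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of scores."""
    peak = max(row)
    log_sum = peak + math.log(sum(math.exp(s - peak) for s in row))
    return [s - log_sum for s in row]


def no_speech_probability(
    scores: List[List[float]], no_speech_token: int, scores_is_logprobs: bool = False
) -> List[float]:
    """Hypothetical list-based version of the exp/softmax branch in `__call__`."""
    result = []
    for row in scores:
        if scores_is_logprobs:
            # Log-probabilities only need to be exponentiated.
            probs = [math.exp(s) for s in row]
        else:
            # Raw logits are normalised with a softmax.
            probs = [math.exp(s) for s in log_softmax(row)]
        result.append(probs[no_speech_token])
    return result


# Toy batch: 2 sequences, 4-token vocabulary, token id 2 plays the no-speech token.
logits = [[0.0, 0.0, 0.0, 0.0], [1.0, 2.0, 3.0, 4.0]]
from_logits = no_speech_probability(logits, no_speech_token=2)

logprobs = [log_softmax(row) for row in logits]
from_logprobs = no_speech_probability(logprobs, no_speech_token=2, scores_is_logprobs=True)
```

Both paths yield the same per-sequence probability, which is why the processor only needs the `scores_is_logprobs` flag to choose between them.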

@@ -518,6 +518,8 @@ class GenerationMixin:
         # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
         elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
             pass
+        elif self.config.model_type in ["whisper"]:
+            pass
         # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
         # decoder_attention_mask if provided)
         elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():

File diff suppressed because it is too large

@@ -13,10 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Whisper model."""
-
-import copy
 import math
-import warnings
 from typing import Optional, Tuple, Union

 import numpy as np
@@ -27,7 +24,6 @@ from torch import nn
 from torch.nn import CrossEntropyLoss

 from ...activations import ACT2FN
-from ...generation.logits_process import WhisperTimeStampLogitsProcessor
 from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -47,7 +43,7 @@ from ...utils import (
     replace_return_docstrings,
 )
 from .configuration_whisper import WhisperConfig
-from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
+from .generation_whisper import WhisperGenerationMixin


 if is_flash_attn_2_available():
@@ -231,87 +227,15 @@ def _compute_mask_indices(
     return spec_aug_mask


-def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
-    """
-    Applies a median filter of width `filter_width` along the last dimension of the input.
-
-    The `inputs` tensor is assumed to be 3- or 4-dimensional.
-    """
-    if filter_width <= 0 or filter_width % 2 != 1:
-        raise ValueError("`filter_width` should be an odd number")
-
-    pad_width = filter_width // 2
-    if inputs.shape[-1] <= pad_width:
-        return inputs
-
-    # Pad the left and right edges.
-    inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
-
-    # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
-    result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
-    return result
-
-
-def _dynamic_time_warping(matrix: np.ndarray):
-    """
-    Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
-    token-level timestamps.
-    """
-    output_length, input_length = matrix.shape
-    cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
-    trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
-
-    cost[0, 0] = 0
-    for j in range(1, input_length + 1):
-        for i in range(1, output_length + 1):
-            c0 = cost[i - 1, j - 1]
-            c1 = cost[i - 1, j]
-            c2 = cost[i, j - 1]
-
-            if c0 < c1 and c0 < c2:
-                c, t = c0, 0
-            elif c1 < c0 and c1 < c2:
-                c, t = c1, 1
-            else:
-                c, t = c2, 2
-
-            cost[i, j] = matrix[i - 1, j - 1] + c
-            trace[i, j] = t
-
-    # backtrace
-    i = trace.shape[0] - 1
-    j = trace.shape[1] - 1
-    trace[0, :] = 2
-    trace[:, 0] = 1
-
-    text_indices = []
-    time_indices = []
-    while i > 0 or j > 0:
-        text_indices.append(i - 1)
-        time_indices.append(j - 1)
-        if trace[i, j] == 0:
-            i -= 1
-            j -= 1
-        elif trace[i, j] == 1:
-            i -= 1
-        elif trace[i, j] == 2:
-            j -= 1
-        else:
-            raise RuntimeError(
-                f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
-            )
-
-    text_indices = np.array(text_indices)[::-1]
-    time_indices = np.array(time_indices)[::-1]
-    return text_indices, time_indices
-
-
 class WhisperPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
         super().__init__(num_positions, embedding_dim)

-    def forward(self, input_ids, past_key_values_length=0):
-        return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
+    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
+        if position_ids is None:
+            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
+        else:
+            return self.weight[position_ids]


 class WhisperAttention(nn.Module):
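
The change to `WhisperPositionalEmbedding.forward` in this hunk switches between a contiguous slice and a per-sequence lookup. A minimal sketch of the difference follows, using nested lists in place of tensors; `embed_positions` as a free function, the `seq_len` argument, and the toy table values are assumptions for illustration, not the library code.

```python
from typing import List, Optional

Row = List[float]


def embed_positions(
    weight: List[Row],
    seq_len: int,
    past_key_values_length: int = 0,
    position_ids: Optional[List[List[int]]] = None,
):
    """List-based stand-in for the updated `forward` (assumed simplification)."""
    if position_ids is None:
        # Previous behaviour: one contiguous slice shared by the whole batch.
        return weight[past_key_values_length : past_key_values_length + seq_len]
    # New behaviour: every sequence in the batch looks up its own positions.
    return [[weight[p] for p in row] for row in position_ids]


# Toy table with 5 positions and embedding size 2 (values are arbitrary).
weight = [[float(p), float(p) + 0.5] for p in range(5)]

sliced = embed_positions(weight, seq_len=2, past_key_values_length=1)
gathered = embed_positions(weight, seq_len=2, position_ids=[[1, 2], [0, 1]])
```

With explicit `position_ids`, two sequences in the same batch can sit at different offsets, which a single shared slice cannot express.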
@@ -1358,6 +1282,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
         cross_attn_head_mask=None,
         past_key_values=None,
         inputs_embeds=None,
+        position_ids=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -1461,9 +1386,13 @@ class WhisperDecoder(WhisperPreTrainedModel):
         # embed positions
         if input_ids is not None:
-            positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )
         else:
-            positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
+            positions = self.embed_positions(
+                inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
+            )

         hidden_states = inputs_embeds + positions
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -1645,6 +1574,7 @@ class WhisperModel(WhisperPreTrainedModel):
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1703,6 +1633,7 @@ class WhisperModel(WhisperPreTrainedModel):
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             inputs_embeds=decoder_inputs_embeds,
+            position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1728,7 +1659,7 @@ class WhisperModel(WhisperPreTrainedModel):
     "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
     WHISPER_START_DOCSTRING,
 )
-class WhisperForConditionalGeneration(WhisperPreTrainedModel):
+class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
     base_model_prefix = "model"
     _tied_weights_keys = ["proj_out.weight"]
@@ -1776,6 +1707,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
+        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1830,6 +1762,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
             decoder_inputs_embeds=decoder_inputs_embeds,
+            decoder_position_ids=decoder_position_ids,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1860,647 +1793,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             encoder_attentions=outputs.encoder_attentions,
         )
def generate(
self,
input_features: Optional[torch.Tensor] = None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
prefix_allowed_tokens_fn=None,
synced_gpus=False,
return_timestamps=None,
task=None,
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
num_segment_frames: Optional[int] = None,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
attention_mask: Optional[torch.Tensor] = None,
time_precision: int = 0.02,
return_dict_in_generate: Optional[bool] = None,
**kwargs,
):
"""
Transcribes or translates passed mel input features to a sequence of token ids.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
For an overview of generation strategies and code examples, check out the [following
guide](./generation_strategies).
</Tip>
Parameters:
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
logits_processor (`LogitsProcessorList`, *optional*):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
stopping_criteria (`StoppingCriteriaList`, *optional*):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
If provided, this function constraints the beam search to allowed tokens only at each step. If not
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
return_timestamps (`bool`, *optional*):
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
task (`str`, *optional*):
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
will be updated accordingly.
language (`str`, *optional*):
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
is_multilingual (`bool`, *optional*):
Whether or not the model is multilingual.
prompt_ids (`torch.Tensor`, *optional*):
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
return_token_timestamps (`bool`, *optional*):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
words.
return_segments (`bool`, *optional*, defaults to `False`):
Whether to additionally return a list of all segments. Note that this option can only be enabled
when doing long-form transcription.
attention_mask (`torch.Tensor`, *optional*):
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
time_precision (`int`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
- [`~generation.GenerateEncoderDecoderOutput`],
- [`~generation.GenerateBeamEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
Example:
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset, Audio
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> model.cuda()
>>> # load audios > 30 seconds
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
>>> # resample to 16kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
>>> # take first 8 audios and retrieve array
>>> audio = ds[:8]["audio"]
>>> audio = [x["array"] for x in audio]
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
>>> inputs = inputs.to("cuda", torch.float32)
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> transcription[0]
' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!'
```
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
"""
if "inputs" in kwargs:
input_features = kwargs.pop("inputs")
warnings.warn(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
if generation_config is None:
generation_config = copy.deepcopy(self.generation_config)
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
else generation_config.return_dict_in_generate
)
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
if num_segment_frames is None:
num_segment_frames = input_stride * self.config.max_source_positions
# 1. Check whether we're in shortform or longform mode
if input_features is not None:
total_input_frames = input_features.shape[-1]
elif "encoder_outputs" in kwargs:
encoder_outputs_shape = (
kwargs["encoder_outputs"][0].shape
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
total_input_frames = encoder_outputs_shape[1] * input_stride
else:
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
is_shortform = total_input_frames <= num_segment_frames
# 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
if return_timestamps is True:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
elif not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the generation config to have `no_timestamps_token_id` correctly. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
generation_config.return_timestamps = True
else:
generation_config.return_timestamps = False
# 3. Make sure to correctly set language-related parameters
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
"to `generate`. Please update the generation config as per the instructions "
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.is_multilingual = is_multilingual
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
if task is not None or language is not None:
raise ValueError(
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
)
if language is not None:
if not hasattr(generation_config, "lang_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `language` argument "
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
language = language.lower()
generation_config.language = language
if task is not None:
if not hasattr(generation_config, "task_to_id"):
raise ValueError(
"The generation config is outdated and is thus not compatible with the `task` argument "
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
)
generation_config.task = task
# 4. Add forced decoder ids depending on the passed `language`, `task`, `prompt_ids`, `return_token_timestamps` and `return_timestamps`
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
elif (
hasattr(self.generation_config, "forced_decoder_ids")
and self.generation_config.forced_decoder_ids is not None
):
forced_decoder_ids = self.generation_config.forced_decoder_ids
else:
forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
forced_decoder_ids = []
if hasattr(generation_config, "language"):
if generation_config.language in generation_config.lang_to_id.keys():
language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
elif generation_config.language in TO_LANGUAGE_CODE.values():
language_token = f"<|{generation_config.language}|>"
else:
is_language_code = len(generation_config.language) == 2
raise ValueError(
f"Unsupported language: {generation_config.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
)
if language_token not in generation_config.lang_to_id:
raise ValueError(
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
" (You should just add it to the generation config)"
)
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else:
forced_decoder_ids.append((1, None)) # automatically detect the language
if hasattr(generation_config, "task"):
if generation_config.task in TASK_IDS:
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
else:
raise ValueError(
f"The `{generation_config.task}` task is not supported. The task should be one of `{TASK_IDS}`"
)
elif hasattr(generation_config, "task_to_id"):
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
if forced_decoder_ids is not None:
generation_config.forced_decoder_ids = forced_decoder_ids
if prompt_ids is not None:
if kwargs.get("decoder_start_token_id") is not None:
raise ValueError(
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
)
prompt_ids = prompt_ids.tolist()
decoder_start_token_id, *text_prompt_ids = prompt_ids
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
# to accommodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :]
# Set the decoder_start_token_id to <|startofprev|>
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
# If the user passes `max_new_tokens`, increase its number to account for the prompt
if kwargs.get("max_new_tokens", None) is not None:
kwargs["max_new_tokens"] += len(text_prompt_ids)
if kwargs["max_new_tokens"] >= self.config.max_target_positions:
raise ValueError(
f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` "
f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced "
f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the "
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
f"so that their combined length is less than {self.config.max_target_positions}."
)
# Reformat the forced_decoder_ids to incorporate the prompt
non_prompt_forced_decoder_ids = (
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
)
forced_decoder_ids = [
*text_prompt_ids,
generation_config.decoder_start_token_id,
*[token for _rank, token in non_prompt_forced_decoder_ids],
]
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids
if return_token_timestamps:
kwargs["output_attentions"] = True
return_dict_in_generate = True
kwargs["output_scores"] = True
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
if not hasattr(generation_config, "alignment_heads"):
raise ValueError(
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
)
if kwargs.get("num_frames") is not None:
generation_config.num_frames = kwargs.pop("num_frames")
if generation_config.return_timestamps is True:
last_forced_decoder_ids = (
generation_config.forced_decoder_ids[-1][-1]
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids
else None
)
if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id:
# remove no_timestamp to be forcefully generated if we want to return timestamps
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
forced_decoder_ids = generation_config.forced_decoder_ids[:-1]
# Make sure that if list is empty we set it to None
generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids
timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
logits_processor = (
timestamp_processor if logits_processor is None else timestamp_processor + logits_processor
)
# 5. If we're in shortform mode, simply generate the whole input at once and return the output
if is_shortform:
outputs = super().generate(
input_features,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
outputs["token_timestamps"] = self._extract_token_timestamps(
outputs, generation_config.alignment_heads, num_frames=num_frames
)
return outputs
# 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated
# timestamp tokens
# 6.1 Set running parameters for while loop
if not return_segments and return_dict_in_generate:
raise ValueError(
"Make sure to set `return_segments=True` to return generation outputs as part of the `'segments'` key."
)
# if input is longer than 30 seconds we default to long-form generation
timestamp_begin = self.generation_config.no_timestamps_token_id + 1
# input stride is mel frames per encoder output vector which is the product of all conv strides
batch_size = input_features.shape[0]
if batch_size > 1 and attention_mask is None:
raise ValueError(
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)`."
)
elif batch_size > 1:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
seek = torch.zeros((batch_size,), dtype=torch.long)
else:
max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
seek = torch.zeros((1,), dtype=torch.long)
current_segments = [[] for _ in range(batch_size)]
cur_to_prev_index_map = list(range(batch_size))
# batch size can decrease during the run
cur_bsz = prev_bsz = batch_size
# 6.2 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
prev_bsz = cur_bsz
# 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
# to know which original audio is being decoded
new_cur_to_prev_index_map = []
for i in range(prev_bsz):
prev_i = cur_to_prev_index_map[i]
if seek[prev_i] >= max_frames[prev_i]:
cut_index = i + (cur_bsz - prev_bsz)
cur_bsz -= 1
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
else:
# cut out index that goes away
new_cur_to_prev_index_map.append(prev_i)
# 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
cur_to_prev_index_map = new_cur_to_prev_index_map
time_offset = seek * time_precision / input_stride
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
# 6.5 Make sure that all inputs are padded to the same input length
segment_input = []
for i in range(cur_bsz):
prev_i = cur_to_prev_index_map[i]
segment_input_slice = input_features[
i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]
]
if segment_input_slice.shape[-1] < num_segment_frames:
# pad to 3000 if necessary
segment_input_slice = F.pad(
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
)
segment_input.append(segment_input_slice)
segment_input = torch.cat(segment_input, dim=0)
# 6.6 Batch generate current chunk
seek_outputs = super().generate(
segment_input,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
if return_dict_in_generate:
seek_sequences = seek_outputs["sequences"]
seek_outputs = [
{k: v[i] for k, v in seek_outputs.items()}
for i in range(next(iter(seek_outputs.values())).size(0))
]
else:
seek_sequences = seek_outputs
# 6.7 Loop over each decoded audio individually as each decoding can be of a different length
for i, seek_sequence in enumerate(seek_sequences):
prev_i = cur_to_prev_index_map[i]
# make sure we cut a predicted EOS token if we are not finished with the generation yet
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
# remove all padding tokens
if seek_sequence[-1] == self.generation_config.pad_token_id:
num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
segments, segment_offset = self._retrieve_segment(
seek_sequence=seek_sequence,
seek_outputs=seek_outputs,
time_offset=time_offset,
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
cur_bsz=cur_bsz,
time_precision=time_precision,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
)
current_segments[prev_i] += segments
seek[prev_i] += segment_offset
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
sequences = []
max_total_length = 0
for current_segment_list in current_segments:
sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1))
max_total_length = max(max_total_length, len(sequences[-1]))
for i in range(batch_size):
sequences[i] = F.pad(
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
)
sequences = torch.stack(sequences, dim=0)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}
return sequences
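The dynamic batch reduction in step 6.3 can be illustrated with a minimal, framework-free sketch. It uses plain Python lists instead of tensors, and `drop_finished_rows` is a hypothetical helper name that does not appear in the source. It filters rows rather than cutting them out by index as the loop above does, but it keeps the same bookkeeping: a map from each surviving batch row to the index of the original audio.

```python
def drop_finished_rows(rows, cur_to_prev_index_map, seek, max_frames):
    """Keep only the batch rows whose audio still has frames left to decode.

    rows: one entry per row of the current (possibly already reduced) batch.
    cur_to_prev_index_map: for each current row, the index of the original audio.
    seek / max_frames: indexed by ORIGINAL audio index, never reduced.
    """
    new_rows, new_map = [], []
    for row, prev_i in zip(rows, cur_to_prev_index_map):
        if seek[prev_i] < max_frames[prev_i]:
            # This audio is not finished yet, so it stays in the batch.
            new_rows.append(row)
            new_map.append(prev_i)
    return new_rows, new_map


# Audio 0 is fully decoded; audios 1 and 2 still have frames left.
rows, index_map = drop_finished_rows(
    rows=["feat_0", "feat_1", "feat_2"],
    cur_to_prev_index_map=[0, 1, 2],
    seek=[3000, 1500, 6000],
    max_frames=[3000, 4000, 9000],
)
```

After the call, row 0 of the reduced batch belongs to original audio 1, which is why `seek`, `max_frames` and `current_segments` are always indexed through the map.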
@staticmethod
def _retrieve_segment(
seek_sequence,
seek_outputs,
time_offset,
timestamp_begin,
seek_num_frames,
cur_bsz,
time_precision,
input_stride,
prev_idx,
idx,
):
# find the predicted "end of segment" predictions of Whisper
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == cur_bsz * [[False, True]]
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
# If Whisper predicted an "end of segment" via a timestamp token, let's go over each
# "end of segment" prediction and slice the decoding into segments accordingly
if len(timestamp_segment_indices) > 0:
# if the output contains two consecutive timestamp tokens
slices = timestamp_segment_indices.tolist()
segments = []
if single_timestamp_ending:
slices.append(len(seek_sequence))
last_slice = 0
# Add each segment to list of all segments
for current_slice in slices:
sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1]
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
"tokens": sliced_tokens,
"result": seek_outputs[idx],
}
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
segment_offset = seek_num_frames[prev_idx]
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
# here we throw away all predictions after the last predicted "end of segment"
# since we are cutting right in the middle of an audio
last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin
segment_offset = last_timestamp_pos * input_stride
else:
# If whisper does not predict any "end of segment" token, then
# the whole decoding is considered a segment and we add it to the list of segments
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
last_timestamp_pos = seek_num_frames[prev_idx]
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
segments = [
{
"start": time_offset[prev_idx],
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
"tokens": seek_sequence,
"result": seek_outputs[idx],
}
]
segment_offset = seek_num_frames[prev_idx]
return segments, segment_offset
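The slicing rule used by `_retrieve_segment` can be sketched on plain Python lists. This is a simplified illustration, not the implementation above: `split_into_segments` is a hypothetical name, batching and time offsets are left out, and the fallback branch for chunks without a closed segment uses a simpler end time than the source.

```python
def split_into_segments(tokens, timestamp_begin, chunk_frames,
                        time_precision=0.02, input_stride=2):
    """Split one decoded chunk at pairs of consecutive timestamp tokens.

    Returns (segments, frames_to_advance).
    """
    is_ts = [t >= timestamp_begin for t in tokens]
    # Index i is a boundary when tokens[i] and tokens[i + 1] are both timestamps.
    boundaries = [i for i in range(len(tokens) - 1) if is_ts[i] and is_ts[i + 1]]
    single_ts_ending = is_ts[-2:] == [False, True]

    if not boundaries:
        # Simplified fallback (assumption): treat the whole chunk as one segment.
        ts = [t for t in tokens if t >= timestamp_begin]
        if ts:
            end = (ts[-1] - timestamp_begin) * time_precision
        else:
            end = chunk_frames / input_stride * time_precision
        return [{"start": 0.0, "end": end, "tokens": tokens}], chunk_frames

    if single_ts_ending:
        boundaries.append(len(tokens))

    segments, last = [], 0
    for b in boundaries:
        # Skip the token at `last` (start token, or first half of a timestamp pair).
        piece = tokens[last + 1 : b + 1]
        segments.append({
            "start": (piece[0] - timestamp_begin) * time_precision,
            "end": (piece[-1] - timestamp_begin) * time_precision,
            "tokens": piece,
        })
        last = b

    if single_ts_ending:
        # No speech after the final timestamp: consume the whole chunk.
        return segments, chunk_frames
    # Otherwise resume from the last closed segment and drop the unfinished tail.
    return segments, (tokens[last] - timestamp_begin) * input_stride


# Token 1 stands in for the start token; ids >= 100 are timestamps (made-up values).
segs, advance = split_into_segments(
    [1, 100, 7, 8, 150, 150, 9, 200, 200], timestamp_begin=100, chunk_frames=3000
)
```

Here the chunk yields two segments, and the next chunk starts 200 mel frames later, because the last closed timestamp is 100 steps of 0.02 s and each step covers `input_stride` frames.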
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
@@ -2508,8 +1800,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         use_cache=None,
         encoder_outputs=None,
         attention_mask=None,
+        decoder_attention_mask=None,
         **kwargs,
     ):
+        decoder_position_ids = None
+        if decoder_attention_mask is not None:
+            decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
+
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]

@@ -2522,12 +1819,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):

             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]

+            if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
+                decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
+
         return {
             "encoder_outputs": encoder_outputs,
             "past_key_values": past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "use_cache": use_cache,
-            "decoder_attention_mask": None,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_position_ids": decoder_position_ids,
         }

     @staticmethod
@@ -2539,99 +1840,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             )
         return reordered_past
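The `decoder_position_ids` computation added to `prepare_inputs_for_generation`, `(decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)`, can be checked by hand with a small pure-Python sketch. `position_ids_from_mask` is a hypothetical helper that works on a single mask row given as a list; the left-padded example mask is an assumption chosen for illustration.

```python
from itertools import accumulate


def position_ids_from_mask(mask_row):
    """Pure-Python equivalent of (mask.cumsum(-1) - 1).clamp(min=0) for one row."""
    # The running sum counts attended tokens so far; subtracting 1 makes the
    # first attended token position 0, and the clamp keeps padding at 0.
    return [max(count - 1, 0) for count in accumulate(mask_row)]


# A decoder input with two padding tokens on the left.
left_padded = position_ids_from_mask([0, 0, 1, 1, 1])
# A decoder input without padding.
unpadded = position_ids_from_mask([1, 1, 1, 1, 1])
```

With left padding the first real token still gets position 0, so rows of different prompt lengths in one batch use consistent positions.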
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
"""
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
cross-attentions will be cropped before applying DTW.
Returns:
tensor containing the timestamps in seconds for each predicted token
"""
# Create a list with `decoder_layers` elements, each a tensor of shape
# (batch size, attention_heads, output length, input length).
cross_attentions = []
for i in range(self.config.decoder_layers):
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
# Select specific cross-attention layers and heads. This is a tensor
# of shape (batch size, num selected, output length, input length).
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
weights = weights.permute([1, 0, 2, 3])
if "beam_indices" in generate_outputs:
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
# since the beam search strategy chooses the most probable sequences at the end of the search.
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
weights = weights[:, :, :weight_length]
# If beam index is still -1, it means that the associated token id is EOS
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
beam_indices = generate_outputs.beam_indices[:, :weight_length]
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
# Select the cross attention from the right beam for each output sequences
weights = torch.stack(
[
torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
for i in range(beam_indices.shape[1])
],
dim=2,
)
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
batch_size = timestamps.shape[0]
if num_frames is not None:
# two cases:
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
# we're using np.unique because num_frames can be int/list/tuple
if len(np.unique(num_frames)) == 1:
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]
weights = weights[..., : num_frames // 2]
else:
# num_frames is of shape (batch_size,) whereas batch_size is truly batch_size*num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
num_frames = np.repeat(num_frames, repeat_time)
if num_frames is None or isinstance(num_frames, int):
# Normalize and smoothen the weights.
std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
mean = torch.mean(weights, dim=-2, keepdim=True)
weights = (weights - mean) / std
weights = _median_filter(weights, self.config.median_filter_width)
# Average the different cross-attention heads.
weights = weights.mean(dim=1)
# Perform dynamic time warping on each element of the batch.
for batch_idx in range(batch_size):
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)):
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
# Normalize and smoothen the weights.
std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
mean = torch.mean(matrix, dim=-2, keepdim=True)
matrix = (matrix - mean) / std
matrix = _median_filter(matrix, self.config.median_filter_width)
# Average the different cross-attention heads.
matrix = matrix.mean(dim=0)
else:
matrix = weights[batch_idx]
text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] * time_precision
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
return timestamps
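The final step of `_extract_token_timestamps` turns a DTW alignment path into one time per token: a "jump" is any path step where the text index changes, and the frame index at that step, scaled by `time_precision`, becomes the token's timestamp. A minimal sketch without NumPy, where `token_start_times` is a hypothetical name and the alignment path is made up for illustration:

```python
def token_start_times(text_indices, time_indices, time_precision=0.02):
    """Return the time (in seconds) at which each token first appears on the path.

    Mirrors the np.diff / np.pad "jumps" logic: the first path step always
    counts as a jump, and every later change of text index is one as well.
    """
    times = []
    previous = None
    for text_idx, frame_idx in zip(text_indices, time_indices):
        if previous is None or text_idx != previous:
            times.append(frame_idx * time_precision)
        previous = text_idx
    return times


# Token 0 spans frames 0-1, token 1 spans frames 2-4, token 2 starts at frame 5.
starts = token_start_times([0, 0, 1, 1, 1, 2], [0, 1, 2, 3, 4, 5])
```

In this example the three tokens start at roughly 0.0 s, 0.04 s and 0.1 s.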
class WhisperDecoderWrapper(WhisperPreTrainedModel):
    """


@@ -530,10 +530,21 @@ class WhisperTokenizer(PreTrainedTokenizer):
         """
         timestamp_begin = self.all_special_ids[-1] + 1
         outputs = [[]]
+
+        cur_max_timestamp = 0.0
+        prev_segments_len = 0.0
+
         for token in token_ids:
             if token >= timestamp_begin:
-                timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
-                outputs.append(timestamp)
+                timestamp = float((token - timestamp_begin) * time_precision)
+
+                if timestamp < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = timestamp
+
+                outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
                 outputs.append([])
             else:
                 outputs[-1].append(token)
@@ -631,7 +642,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = None,
         output_offsets: bool = False,
-        time_precision=0.02,
+        time_precision: float = 0.02,
         decode_with_timestamps: bool = False,
         normalize: bool = False,
         basic_normalize: bool = False,
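The tokenizer change above makes decoded timestamps cumulative across long-form chunks: whenever a timestamp is smaller than the previous one, a new 30-second window is assumed to have started, and the previous window's last timestamp is added to a running offset. A standalone sketch of that bookkeeping follows; `cumulative_timestamp_strings` is a hypothetical name, the token ids are made up, and only the timestamp strings are returned, not the text tokens.

```python
def cumulative_timestamp_strings(token_ids, timestamp_begin, time_precision=0.02):
    """Format timestamp tokens as "<|t|>" strings with a running offset."""
    formatted = []
    cur_max_timestamp = 0.0   # last timestamp seen in the current window
    prev_segments_len = 0.0   # total duration of all completed windows
    for token in token_ids:
        if token < timestamp_begin:
            continue  # ordinary text token, ignored in this sketch
        timestamp = float((token - timestamp_begin) * time_precision)
        if timestamp < cur_max_timestamp:
            # Timestamps went backwards: the next window has started.
            prev_segments_len += cur_max_timestamp
        cur_max_timestamp = timestamp
        formatted.append(f"<|{timestamp + prev_segments_len:.2f}|>")
    return formatted


# Ids >= 1000 are timestamps; the second window restarts at 1000 (0.00 s).
stamps = cumulative_timestamp_strings(
    [1000, 5, 1100, 1100, 6, 1200, 1000, 7, 1050], timestamp_begin=1000
)
```

The second window's local `0.00` and `1.00` are reported as `4.00` and `5.00`, continuing from where the first window ended.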


@@ -224,10 +224,21 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         """
         timestamp_begin = self.all_special_ids[-1] + 1
         outputs = [[]]
+
+        cur_max_timestamp = 0.0
+        prev_segments_len = 0.0
+
         for token in token_ids:
             if token >= timestamp_begin:
-                timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
-                outputs.append(timestamp)
+                timestamp = float((token - timestamp_begin) * time_precision)
+
+                if timestamp < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = timestamp
+
+                outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
                 outputs.append([])
             else:
                 outputs[-1].append(token)
@@ -330,7 +341,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
         skip_special_tokens: bool = False,
         clean_up_tokenization_spaces: bool = None,
         output_offsets: bool = False,
-        time_precision=0.02,
+        time_precision: float = 0.02,
         decode_with_timestamps: bool = False,
         normalize: bool = False,
         basic_normalize: bool = False,

File diff suppressed because one or more lines are too long


@@ -1152,7 +1152,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
     @slow
     def test_whisper_longform(self):
         # fmt: off
-        EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
+        EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."""
         # fmt: on

         processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")