[Whisper] Add sequential longform decoding (#27492)
* [Whisper] Add seq gen * [Whisper] Add seq gen * more debug * Fix whisper logit processor * Improve whisper code further * Fix more * more debug * more debug * Improve further * Add tests * Prep for batch size > 1 * Get batch_size>1 working * Correct more * Add extensive tests * more debug * more debug * more debug * add more tests * more debug * Apply suggestions from code review * more debug * add comments to explain the code better * add comments to explain the code better * add comments to explain the code better * Add more examples * add comments to explain the code better * fix more * add comments to explain the code better * add comments to explain the code better * correct * correct * finalize * Apply suggestions from code review * Apply suggestions from code review
This commit is contained in:
committed by
GitHub
parent
b2c63c79c3
commit
4151fbb49c
@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
||||
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
||||
predicting timestamps that are too far in the future.
|
||||
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
|
||||
|
||||
Examples:
|
||||
``` python
|
||||
@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, generate_config): # support for the kwargs
|
||||
def __init__(
|
||||
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
|
||||
): # support for the kwargs
|
||||
self.eos_token_id = generate_config.eos_token_id
|
||||
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
||||
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
|
||||
|
||||
self.begin_index = len(generate_config.forced_decoder_ids) + 2
|
||||
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
|
||||
self.begin_index -= 1
|
||||
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
|
||||
# this variable is mostly just used for testing
|
||||
self._detect_timestamp_from_logprob = (
|
||||
_detect_timestamp_from_logprob
|
||||
if _detect_timestamp_from_logprob is not None
|
||||
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
|
||||
)
|
||||
|
||||
self.begin_index = (
|
||||
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
|
||||
)
|
||||
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
|
||||
|
||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
scores[:, self.no_timestamps_token_id] = -float("inf")
|
||||
|
||||
if input_ids.shape[1] == self.begin_index - 1:
|
||||
scores[:, :] = -float("inf")
|
||||
scores[:, self.timestamp_begin] = 0
|
||||
return scores
|
||||
|
||||
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
|
||||
for k in range(input_ids.shape[0]):
|
||||
seq = list(input_ids[k, self.begin_index :].tolist())
|
||||
sampled_tokens = input_ids[k, self.begin_index :]
|
||||
seq = list(sampled_tokens.tolist())
|
||||
|
||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
|
||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
|
||||
|
||||
@@ -1549,8 +1556,23 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
else: # cannot be normal text tokens
|
||||
scores[k, : self.eos_token_id] = -float("inf")
|
||||
|
||||
timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
|
||||
if timestamps.numel() > 0:
|
||||
# `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
# The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
|
||||
if last_was_timestamp and not penultimate_was_timestamp:
|
||||
timestamp_last = timestamps[-1]
|
||||
else:
|
||||
# Avoid to emit <|0.00|> again
|
||||
timestamp_last = timestamps[-1] + 1
|
||||
|
||||
scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
|
||||
if input_ids.shape[1] == self.begin_index:
|
||||
scores[:, : self.timestamp_begin] = -float("inf")
|
||||
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
|
||||
scores[:, last_allowed + 1 :] = -float("inf")
|
||||
|
||||
@@ -1559,7 +1581,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
||||
for k in range(input_ids.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
|
||||
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
|
||||
scores[k, : self.timestamp_begin] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
""" PyTorch Whisper model."""
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -1111,6 +1112,13 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
|
||||
expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
|
||||
if input_features.shape[-1] != expected_seq_length:
|
||||
raise ValueError(
|
||||
f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
|
||||
)
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
@@ -1723,7 +1731,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
|
||||
def generate(
|
||||
self,
|
||||
inputs: Optional[torch.Tensor] = None,
|
||||
input_features: Optional[torch.Tensor] = None,
|
||||
generation_config=None,
|
||||
logits_processor=None,
|
||||
stopping_criteria=None,
|
||||
@@ -1734,12 +1742,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
language=None,
|
||||
is_multilingual=None,
|
||||
prompt_ids: Optional[torch.Tensor] = None,
|
||||
return_token_timestamps=None,
|
||||
num_segment_frames: Optional[int] = None,
|
||||
return_token_timestamps: Optional[bool] = None,
|
||||
return_segments: bool = False,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
time_precision: int = 0.02,
|
||||
return_dict_in_generate: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Generates sequences of token ids for models with a language modeling head.
|
||||
Transcribes or translates passed mel input features to a sequence of token ids.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
@@ -1801,46 +1813,162 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
Whether to return token-level timestamps with the text. This can be used with or without the
|
||||
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
||||
words.
|
||||
return_segments (`bool`, *optional*, defaults to `False`):
|
||||
Whether to additionally return a list of all segments. Note that this option can only be enabled
|
||||
when doing long-form transcription.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
|
||||
time_precision (`int`, *optional*, defaults to 0.02):
|
||||
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
|
||||
for 20 ms.
|
||||
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
|
||||
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
|
||||
`return_segments` is set True. In this case the generation outputs of each segment is added to each
|
||||
segment.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
|
||||
|
||||
Return:
|
||||
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
||||
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
|
||||
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
||||
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
|
||||
|
||||
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
|
||||
[`~utils.ModelOutput`] types are:
|
||||
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
|
||||
|
||||
- [`~generation.GreedySearchDecoderOnlyOutput`],
|
||||
- [`~generation.SampleDecoderOnlyOutput`],
|
||||
- [`~generation.BeamSearchDecoderOnlyOutput`],
|
||||
- [`~generation.BeamSampleDecoderOnlyOutput`]
|
||||
|
||||
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
|
||||
[`~utils.ModelOutput`] types are:
|
||||
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
|
||||
|
||||
- [`~generation.GreedySearchEncoderDecoderOutput`],
|
||||
- [`~generation.SampleEncoderDecoderOutput`],
|
||||
- [`~generation.BeamSearchEncoderDecoderOutput`],
|
||||
- [`~generation.BeamSampleEncoderDecoderOutput`]
|
||||
|
||||
else only the generated output sequence ids are returned.
|
||||
|
||||
Example:
|
||||
|
||||
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
||||
>>> from datasets import load_dataset, Audio
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model.cuda()
|
||||
|
||||
>>> # load audios > 30 seconds
|
||||
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
||||
>>> # resample to 16kHz
|
||||
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
|
||||
>>> # take first 8 audios and retrieve array
|
||||
>>> audio = ds[:8]["audio"]
|
||||
>>> audio = [x["array"] for x in audio]
|
||||
|
||||
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
|
||||
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
|
||||
>>> inputs = inputs.to("cuda", torch.float32)
|
||||
|
||||
>>> # transcribe audio to ids
|
||||
>>> generated_ids = model.generate(**inputs)
|
||||
|
||||
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
>>> transcription[0]
|
||||
' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!'
|
||||
```
|
||||
|
||||
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
|
||||
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||
>>> input_features = inputs.input_features
|
||||
|
||||
>>> generated_ids = model.generate(inputs=input_features)
|
||||
|
||||
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
>>> transcription
|
||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
if "inputs" in kwargs:
|
||||
input_features = kwargs.pop("inputs")
|
||||
warnings.warn(
|
||||
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
return_dict_in_generate = (
|
||||
return_dict_in_generate
|
||||
if return_dict_in_generate is not None
|
||||
else self.generation_config.return_dict_in_generate
|
||||
)
|
||||
|
||||
if generation_config is None:
|
||||
generation_config = self.generation_config
|
||||
|
||||
if return_timestamps is not None:
|
||||
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
|
||||
if num_segment_frames is None:
|
||||
num_segment_frames = input_stride * self.config.max_source_positions
|
||||
|
||||
# 1. Check whether we're in shortform or longform mode
|
||||
if input_features is not None:
|
||||
total_input_frames = input_features.shape[-1]
|
||||
elif "encoder_outputs" in kwargs:
|
||||
encoder_outputs_shape = (
|
||||
kwargs["encoder_outputs"][0].shape
|
||||
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
|
||||
else kwargs["encoder_outputs"].shape
|
||||
)
|
||||
total_input_frames = encoder_outputs_shape[1] * input_stride
|
||||
else:
|
||||
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
|
||||
|
||||
is_shortform = total_input_frames <= num_segment_frames
|
||||
|
||||
# 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
|
||||
if return_timestamps is True:
|
||||
if not hasattr(generation_config, "no_timestamps_token_id"):
|
||||
raise ValueError(
|
||||
"You are trying to return timestamps, but the generation config is not properly set. "
|
||||
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
|
||||
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
||||
)
|
||||
|
||||
generation_config.return_timestamps = return_timestamps
|
||||
elif not is_shortform:
|
||||
if return_timestamps is False:
|
||||
raise ValueError(
|
||||
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
|
||||
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
|
||||
)
|
||||
|
||||
if not hasattr(generation_config, "no_timestamps_token_id"):
|
||||
raise ValueError(
|
||||
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
|
||||
"requires the generation config to have `no_timestamps_token_id` correctly. "
|
||||
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
|
||||
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
||||
"or make sure to pass no more than 3000 mel input features."
|
||||
)
|
||||
|
||||
logger.info("Setting `return_timestamps=True` for long-form generation.")
|
||||
generation_config.return_timestamps = True
|
||||
else:
|
||||
generation_config.return_timestamps = False
|
||||
|
||||
# 3. Make sure to correctly set language-related parameters
|
||||
if is_multilingual is not None:
|
||||
if not hasattr(generation_config, "is_multilingual"):
|
||||
raise ValueError(
|
||||
@@ -1875,8 +2003,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
)
|
||||
generation_config.task = task
|
||||
|
||||
# 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps`
|
||||
forced_decoder_ids = None
|
||||
|
||||
# Legacy code for backward compatibility
|
||||
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
|
||||
forced_decoder_ids = self.config.forced_decoder_ids
|
||||
@@ -1961,12 +2089,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
|
||||
generation_config.forced_decoder_ids = forced_decoder_ids
|
||||
|
||||
if generation_config.return_timestamps:
|
||||
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||
|
||||
if return_token_timestamps:
|
||||
kwargs["output_attentions"] = True
|
||||
kwargs["return_dict_in_generate"] = True
|
||||
return_dict_in_generate = True
|
||||
|
||||
if getattr(generation_config, "task", None) == "translate":
|
||||
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
||||
@@ -1979,13 +2104,34 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
if kwargs.get("num_frames") is not None:
|
||||
generation_config.num_frames = kwargs.pop("num_frames")
|
||||
|
||||
if generation_config.return_timestamps is True:
|
||||
last_forced_decoder_ids = (
|
||||
generation_config.forced_decoder_ids[-1][-1]
|
||||
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids
|
||||
else None
|
||||
)
|
||||
if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id:
|
||||
# remove no_timestamp to be forcefully generated if we want to return timestamps
|
||||
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
|
||||
forced_decoder_ids = generation_config.forced_decoder_ids[:-1]
|
||||
# Make sure that if list is empty we set it to None
|
||||
generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids
|
||||
|
||||
timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||
logits_processor = (
|
||||
timestamp_processor if logits_processor is None else timestamp_processor + logits_processor
|
||||
)
|
||||
|
||||
# 5. If we're in shortform mode, simple generate the whole input at once and return the output
|
||||
if is_shortform:
|
||||
outputs = super().generate(
|
||||
inputs,
|
||||
input_features,
|
||||
generation_config,
|
||||
logits_processor,
|
||||
stopping_criteria,
|
||||
prefix_allowed_tokens_fn,
|
||||
synced_gpus,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1997,6 +2143,229 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
||||
|
||||
return outputs
|
||||
|
||||
# 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated
|
||||
# timestamp tokens
|
||||
# 6.1 Set running parameters for while loop
|
||||
if not return_segments and return_dict_in_generate:
|
||||
raise ValueError(
|
||||
"Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`"
|
||||
)
|
||||
|
||||
# if input is longer than 30 seconds we default to long-form generation
|
||||
timestamp_begin = self.generation_config.no_timestamps_token_id + 1
|
||||
# input stride is mel frames per encoder output vector which is the product of all conv strides
|
||||
batch_size = input_features.shape[0]
|
||||
|
||||
if batch_size > 1 and attention_mask is None:
|
||||
raise ValueError(
|
||||
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
|
||||
)
|
||||
elif batch_size > 1:
|
||||
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
|
||||
seek = torch.zeros((batch_size,), dtype=torch.long)
|
||||
else:
|
||||
max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
|
||||
seek = torch.zeros((1,), dtype=torch.long)
|
||||
|
||||
current_segments = [[] for _ in range(batch_size)]
|
||||
cur_to_prev_index_map = list(range(batch_size))
|
||||
|
||||
# batch size can decrease during the run
|
||||
cur_bsz = prev_bsz = batch_size
|
||||
|
||||
# 6.2 Transcribe audio until we reach the end of all input audios
|
||||
while (seek < max_frames).any():
|
||||
prev_bsz = cur_bsz
|
||||
|
||||
# 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
|
||||
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
|
||||
# to know which original audio is being decoded
|
||||
new_cur_to_prev_index_map = []
|
||||
for i in range(prev_bsz):
|
||||
prev_i = cur_to_prev_index_map[i]
|
||||
if seek[prev_i] >= max_frames[prev_i]:
|
||||
cut_index = i + (cur_bsz - prev_bsz)
|
||||
cur_bsz -= 1
|
||||
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
|
||||
else:
|
||||
# cut out index that goes away
|
||||
new_cur_to_prev_index_map.append(prev_i)
|
||||
|
||||
# 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
|
||||
cur_to_prev_index_map = new_cur_to_prev_index_map
|
||||
time_offset = seek * time_precision / input_stride
|
||||
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
|
||||
|
||||
# 6.5 Make sure that all inputs are padded to the same input length
|
||||
segment_input = []
|
||||
for i in range(cur_bsz):
|
||||
prev_i = cur_to_prev_index_map[i]
|
||||
segment_input_slice = input_features[
|
||||
i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]
|
||||
]
|
||||
|
||||
if segment_input_slice.shape[-1] < num_segment_frames:
|
||||
# pad to 3000 if necessary
|
||||
segment_input_slice = F.pad(
|
||||
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
|
||||
)
|
||||
|
||||
segment_input.append(segment_input_slice)
|
||||
|
||||
segment_input = torch.cat(segment_input, dim=0)
|
||||
|
||||
# 6.6 Batch generate current chunk
|
||||
seek_outputs = super().generate(
|
||||
segment_input,
|
||||
generation_config,
|
||||
logits_processor,
|
||||
stopping_criteria,
|
||||
prefix_allowed_tokens_fn,
|
||||
synced_gpus,
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
||||
num_frames = getattr(generation_config, "num_frames", None)
|
||||
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
||||
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
|
||||
)
|
||||
|
||||
if return_dict_in_generate:
|
||||
seek_sequences = seek_outputs["sequences"]
|
||||
seek_outputs = [
|
||||
{k: v[i] for k, v in seek_outputs.items()}
|
||||
for i in range(next(iter(seek_outputs.values())).size(0))
|
||||
]
|
||||
else:
|
||||
seek_sequences = seek_outputs
|
||||
|
||||
# 6.7 Loop over each decoded audio individually as each decoding can be of a different length
|
||||
for i, seek_sequence in enumerate(seek_sequences):
|
||||
prev_i = cur_to_prev_index_map[i]
|
||||
|
||||
# make sure we cut a predicted EOS token if we are not finished with the generation yet
|
||||
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
|
||||
if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id:
|
||||
seek_sequence = seek_sequence[:-1]
|
||||
|
||||
# remove all padding tokens
|
||||
if seek_sequence[-1] == self.generation_config.pad_token_id:
|
||||
num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum()
|
||||
seek_sequence = seek_sequence[:-num_paddings]
|
||||
|
||||
segments, segment_offset = self._retrieve_segment(
|
||||
seek_sequence=seek_sequence,
|
||||
seek_outputs=seek_outputs,
|
||||
time_offset=time_offset,
|
||||
timestamp_begin=timestamp_begin,
|
||||
seek_num_frames=seek_num_frames,
|
||||
cur_bsz=cur_bsz,
|
||||
time_precision=time_precision,
|
||||
input_stride=input_stride,
|
||||
prev_idx=prev_i,
|
||||
idx=i,
|
||||
)
|
||||
|
||||
current_segments[prev_i] += segments
|
||||
seek[prev_i] += segment_offset
|
||||
|
||||
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
|
||||
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
|
||||
sequences = []
|
||||
max_total_length = 0
|
||||
for current_segment_list in current_segments:
|
||||
sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1))
|
||||
max_total_length = max(max_total_length, len(sequences[-1]))
|
||||
|
||||
for i in range(batch_size):
|
||||
sequences[i] = F.pad(
|
||||
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
|
||||
)
|
||||
|
||||
sequences = torch.stack(sequences, dim=0)
|
||||
|
||||
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
|
||||
if return_segments:
|
||||
return {"sequences": sequences, "segments": current_segments}
|
||||
|
||||
return sequences
|
||||
|
||||
@staticmethod
|
||||
def _retrieve_segment(
|
||||
seek_sequence,
|
||||
seek_outputs,
|
||||
time_offset,
|
||||
timestamp_begin,
|
||||
seek_num_frames,
|
||||
cur_bsz,
|
||||
time_precision,
|
||||
input_stride,
|
||||
prev_idx,
|
||||
idx,
|
||||
):
|
||||
# find the predicted "end of segment" predictions of Whisper
|
||||
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
|
||||
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == cur_bsz * [[False, True]]
|
||||
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
|
||||
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
|
||||
# "end of segment" prediction and slice the decoding into segments accordingly
|
||||
if len(timestamp_segment_indices) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = timestamp_segment_indices.tolist()
|
||||
segments = []
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(seek_sequence))
|
||||
|
||||
last_slice = 0
|
||||
# Add each segment to list of all segments
|
||||
for current_slice in slices:
|
||||
sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1]
|
||||
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
|
||||
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
|
||||
segments.append(
|
||||
{
|
||||
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
|
||||
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
|
||||
"tokens": sliced_tokens,
|
||||
"result": seek_outputs[idx],
|
||||
}
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
segment_offset = seek_num_frames[prev_idx]
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
# here we throw away all predictions after the last predicted "end of segment"
|
||||
# since we are cutting right in the middle of an audio
|
||||
last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin
|
||||
segment_offset = last_timestamp_pos * input_stride
|
||||
else:
|
||||
# If whisper does not predict any "end of segment" token, then
|
||||
# the whole decoding is considered a segment and we add it to the list of segments
|
||||
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
|
||||
last_timestamp_pos = seek_num_frames[prev_idx]
|
||||
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
|
||||
|
||||
segments = [
|
||||
{
|
||||
"start": time_offset[prev_idx],
|
||||
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
|
||||
"tokens": seek_sequence,
|
||||
"result": seek_outputs[idx],
|
||||
}
|
||||
]
|
||||
segment_offset = seek_num_frames[prev_idx]
|
||||
|
||||
return segments, segment_offset
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
decoder_input_ids,
|
||||
@@ -2229,7 +2598,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
||||
>>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
|
||||
|
||||
>>> # decode token ids to text
|
||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
||||
>>> transcription
|
||||
' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
|
||||
```"""
|
||||
|
||||
@@ -507,10 +507,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
|
||||
):
|
||||
yield item
|
||||
else:
|
||||
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
|
||||
processed = self.feature_extractor(
|
||||
inputs,
|
||||
sampling_rate=self.feature_extractor.sampling_rate,
|
||||
truncation=False,
|
||||
padding="longest",
|
||||
return_tensors="pt",
|
||||
)
|
||||
else:
|
||||
processed = self.feature_extractor(
|
||||
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
|
||||
)
|
||||
|
||||
if self.torch_dtype is not None:
|
||||
processed = processed.to(dtype=self.torch_dtype)
|
||||
if stride is not None:
|
||||
@@ -551,8 +561,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if stride is not None:
|
||||
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
|
||||
|
||||
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
|
||||
generate_kwargs["input_features"] = inputs
|
||||
else:
|
||||
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
|
||||
|
||||
tokens = self.model.generate(
|
||||
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from datasets import Audio, load_dataset
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from transformers import (
|
||||
@@ -329,16 +329,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
res,
|
||||
{
|
||||
"text": " Conquered returned to its place amidst the tents.",
|
||||
"chunks": [
|
||||
{"text": " Conquered", "timestamp": (0.5, 1.2)},
|
||||
{"text": " returned", "timestamp": (1.2, 1.64)},
|
||||
{"text": " to", "timestamp": (1.64, 1.84)},
|
||||
{"text": " its", "timestamp": (1.84, 2.02)},
|
||||
{"text": " place", "timestamp": (2.02, 2.28)},
|
||||
{"text": " amidst", "timestamp": (2.28, 2.78)},
|
||||
{"text": " the", "timestamp": (2.78, 2.96)},
|
||||
{"text": " tents.", "timestamp": (2.96, 3.48)},
|
||||
'text': ' Conquered returned to its place amidst the tents.',
|
||||
'chunks': [
|
||||
{'text': ' Conquered', 'timestamp': (0.5, 1.2)},
|
||||
{'text': ' returned', 'timestamp': (1.2, 1.64)},
|
||||
{'text': ' to', 'timestamp': (1.64, 1.84)},
|
||||
{'text': ' its', 'timestamp': (1.84, 2.02)},
|
||||
{'text': ' place', 'timestamp': (2.02, 2.28)},
|
||||
{'text': ' amidst', 'timestamp': (2.28, 2.8)},
|
||||
{'text': ' the', 'timestamp': (2.8, 2.98)},
|
||||
{'text': ' tents.', 'timestamp': (2.98, 3.48)},
|
||||
],
|
||||
},
|
||||
)
|
||||
@@ -776,27 +776,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||
"chunks": [
|
||||
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
|
||||
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
|
||||
'text': ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
|
||||
'chunks': [
|
||||
{'text': ' Mr.', 'timestamp': (0.38, 1.04)},
|
||||
{'text': ' Quilter', 'timestamp': (1.04, 1.18)},
|
||||
{'text': ' is', 'timestamp': (1.18, 1.44)},
|
||||
{'text': ' the', 'timestamp': (1.44, 1.58)},
|
||||
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
|
||||
{'text': ' of', 'timestamp': (1.98, 2.3)},
|
||||
{'text': ' the', 'timestamp': (2.3, 2.46)},
|
||||
{'text': ' of', 'timestamp': (1.98, 2.32)},
|
||||
{'text': ' the', 'timestamp': (2.32, 2.46)},
|
||||
{'text': ' middle', 'timestamp': (2.46, 2.56)},
|
||||
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
|
||||
{'text': ' and', 'timestamp': (3.38, 3.52)},
|
||||
{'text': ' we', 'timestamp': (3.52, 3.6)},
|
||||
{'text': ' are', 'timestamp': (3.6, 3.72)},
|
||||
{'text': ' classes,', 'timestamp': (2.56, 3.4)},
|
||||
{'text': ' and', 'timestamp': (3.4, 3.54)},
|
||||
{'text': ' we', 'timestamp': (3.54, 3.62)},
|
||||
{'text': ' are', 'timestamp': (3.62, 3.72)},
|
||||
{'text': ' glad', 'timestamp': (3.72, 4.0)},
|
||||
{'text': ' to', 'timestamp': (4.0, 4.26)},
|
||||
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
|
||||
{'text': ' his', 'timestamp': (4.54, 4.92)},
|
||||
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
|
||||
],
|
||||
},
|
||||
{'text': ' welcome', 'timestamp': (4.26, 4.56)},
|
||||
{'text': ' his', 'timestamp': (4.56, 4.92)},
|
||||
{'text': ' gospel.', 'timestamp': (4.92, 5.84)}
|
||||
]
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@@ -1087,6 +1087,34 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_whisper_longform(self):
|
||||
# fmt: off
|
||||
EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
|
||||
# fmt: on
|
||||
|
||||
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
model = model.to("cuda")
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
max_new_tokens=128,
|
||||
device="cuda:0",
|
||||
)
|
||||
|
||||
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
||||
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
|
||||
audio = ds[:1]["audio"]
|
||||
|
||||
result = pipe(audio)[0]["text"]
|
||||
|
||||
assert result == EXPECTED_RESULT
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_chunking_and_timestamps(self):
|
||||
@@ -1355,7 +1383,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
out,
|
||||
{
|
||||
"chunks": [
|
||||
{"text": "", "timestamp": (18.94, 0.0)},
|
||||
{"text": "", "timestamp": (18.94, 0.02)},
|
||||
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
|
||||
],
|
||||
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
|
||||
|
||||
Reference in New Issue
Block a user