[Whisper] Add sequential longform decoding (#27492)

* [Whisper] Add seq gen

* [Whisper] Add seq gen

* more debug

* Fix whisper logit processor

* Improve whisper code further

* Fix more

* more debug

* more debug

* Improve further

* Add tests

* Prep for batch size > 1

* Get batch_size>1 working

* Correct more

* Add extensive tests

* more debug

* more debug

* more debug

* add more tests

* more debug

* Apply suggestions from code review

* more debug

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* Add more examples

* add comments to explain the code better

* fix more

* add comments to explain the code better

* add comments to explain the code better

* correct

* correct

* finalize

* Apply suggestions from code review

* Apply suggestions from code review
This commit is contained in:
Patrick von Platen
2023-11-22 13:27:34 +01:00
committed by GitHub
parent b2c63c79c3
commit 4151fbb49c
5 changed files with 836 additions and 83 deletions

View File

@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
``` python
@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
```
"""
def __init__(self, generate_config): # support for the kwargs
def __init__(
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.begin_index = len(generate_config.forced_decoder_ids) + 2
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
self.begin_index -= 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
# this variable is mostly just used for testing
self._detect_timestamp_from_logprob = (
_detect_timestamp_from_logprob
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)
self.begin_index = (
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")
if input_ids.shape[1] == self.begin_index - 1:
scores[:, :] = -float("inf")
scores[:, self.timestamp_begin] = 0
return scores
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
seq = list(input_ids[k, self.begin_index :].tolist())
sampled_tokens = input_ids[k, self.begin_index :]
seq = list(sampled_tokens.tolist())
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
@@ -1549,8 +1556,23 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
else: # cannot be normal text tokens
scores[k, : self.eos_token_id] = -float("inf")
timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
if timestamps.numel() > 0:
# `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
# The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
# Avoid to emit <|0.00|> again
timestamp_last = timestamps[-1] + 1
scores[k, self.timestamp_begin : timestamp_last] = -float("inf")
# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
if input_ids.shape[1] == self.begin_index:
scores[:, : self.timestamp_begin] = -float("inf")
if self.max_initial_timestamp_index is not None:
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores[:, last_allowed + 1 :] = -float("inf")
@@ -1559,7 +1581,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
for k in range(input_ids.shape[0]):
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
scores[k, : self.timestamp_begin] = -float("inf")
return scores

View File

@@ -15,6 +15,7 @@
""" PyTorch Whisper model."""
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -1111,6 +1112,13 @@ class WhisperEncoder(WhisperPreTrainedModel):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
if input_features.shape[-1] != expected_seq_length:
raise ValueError(
f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1723,7 +1731,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
def generate(
self,
inputs: Optional[torch.Tensor] = None,
input_features: Optional[torch.Tensor] = None,
generation_config=None,
logits_processor=None,
stopping_criteria=None,
@@ -1734,12 +1742,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
language=None,
is_multilingual=None,
prompt_ids: Optional[torch.Tensor] = None,
return_token_timestamps=None,
num_segment_frames: Optional[int] = None,
return_token_timestamps: Optional[bool] = None,
return_segments: bool = False,
attention_mask: Optional[torch.Tensor] = None,
time_precision: int = 0.02,
return_dict_in_generate: Optional[bool] = None,
**kwargs,
):
"""
Generates sequences of token ids for models with a language modeling head.
Transcribes or translates passed mel input features to a sequence of token ids.
<Tip warning={true}>
@@ -1801,46 +1813,162 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
Whether to return token-level timestamps with the text. This can be used with or without the
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
words.
return_segments (`bool`, *optional*, defaults to `False`):
Whether to additionally return a list of all segments. Note that this option can only be enabled
when doing long-form transcription.
attention_mask (`torch.Tensor`, *optional*):
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
time_precision (`float`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
`return_segments` is set True. In this case the generation outputs of each segment is added to each
segment.
kwargs (`Dict[str, Any]`, *optional*):
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
[`~utils.ModelOutput`] types are:
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
- [`~generation.GreedySearchDecoderOnlyOutput`],
- [`~generation.SampleDecoderOnlyOutput`],
- [`~generation.BeamSearchDecoderOnlyOutput`],
- [`~generation.BeamSampleDecoderOnlyOutput`]
If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
else if the passed input is <= 30 seconds / <= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
else only the generated output sequence ids are returned.
Example:
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset, Audio
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> model.cuda()
>>> # load audios > 30 seconds
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
>>> # resample to 16kHz
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
>>> # take first 8 audios and retrieve array
>>> audio = ds[:8]["audio"]
>>> audio = [x["array"] for x in audio]
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
>>> inputs = inputs.to("cuda", torch.float32)
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
>>> transcription[0]
' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!'
```
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```
"""
if "inputs" in kwargs:
input_features = kwargs.pop("inputs")
warnings.warn(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)
return_dict_in_generate = (
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
if generation_config is None:
generation_config = self.generation_config
if return_timestamps is not None:
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
if num_segment_frames is None:
num_segment_frames = input_stride * self.config.max_source_positions
# 1. Check whether we're in shortform or longform mode
if input_features is not None:
total_input_frames = input_features.shape[-1]
elif "encoder_outputs" in kwargs:
encoder_outputs_shape = (
kwargs["encoder_outputs"][0].shape
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
else kwargs["encoder_outputs"].shape
)
total_input_frames = encoder_outputs_shape[1] * input_stride
else:
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
is_shortform = total_input_frames <= num_segment_frames
# 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
if return_timestamps is True:
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You are trying to return timestamps, but the generation config is not properly set. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
)
generation_config.return_timestamps = return_timestamps
elif not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
)
if not hasattr(generation_config, "no_timestamps_token_id"):
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
"requires the generation config to have `no_timestamps_token_id` correctly. "
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
"or make sure to pass no more than 3000 mel input features."
)
logger.info("Setting `return_timestamps=True` for long-form generation.")
generation_config.return_timestamps = True
else:
generation_config.return_timestamps = False
# 3. Make sure to correctly set language-related parameters
if is_multilingual is not None:
if not hasattr(generation_config, "is_multilingual"):
raise ValueError(
@@ -1875,8 +2003,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
)
generation_config.task = task
# 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps`
forced_decoder_ids = None
# Legacy code for backward compatibility
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
forced_decoder_ids = self.config.forced_decoder_ids
@@ -1961,12 +2089,9 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
generation_config.forced_decoder_ids = forced_decoder_ids
if generation_config.return_timestamps:
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
if return_token_timestamps:
kwargs["output_attentions"] = True
kwargs["return_dict_in_generate"] = True
return_dict_in_generate = True
if getattr(generation_config, "task", None) == "translate":
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
@@ -1979,13 +2104,34 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
if kwargs.get("num_frames") is not None:
generation_config.num_frames = kwargs.pop("num_frames")
if generation_config.return_timestamps is True:
last_forced_decoder_ids = (
generation_config.forced_decoder_ids[-1][-1]
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids
else None
)
if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id:
# remove no_timestamp to be forcefully generated if we want to return timestamps
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
forced_decoder_ids = generation_config.forced_decoder_ids[:-1]
# Make sure that if list is empty we set it to None
generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids
timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
logits_processor = (
timestamp_processor if logits_processor is None else timestamp_processor + logits_processor
)
# 5. If we're in shortform mode, simply generate the whole input at once and return the output
if is_shortform:
outputs = super().generate(
inputs,
input_features,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
@@ -1997,6 +2143,229 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
return outputs
# 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated
# timestamp tokens
# 6.1 Set running parameters for while loop
if not return_segments and return_dict_in_generate:
raise ValueError(
"Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`"
)
# if input is longer than 30 seconds we default to long-form generation
timestamp_begin = self.generation_config.no_timestamps_token_id + 1
# input stride is mel frames per encoder output vector which is the product of all conv strides
batch_size = input_features.shape[0]
if batch_size > 1 and attention_mask is None:
raise ValueError(
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
)
elif batch_size > 1:
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
seek = torch.zeros((batch_size,), dtype=torch.long)
else:
max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
seek = torch.zeros((1,), dtype=torch.long)
current_segments = [[] for _ in range(batch_size)]
cur_to_prev_index_map = list(range(batch_size))
# batch size can decrease during the run
cur_bsz = prev_bsz = batch_size
# 6.2 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
prev_bsz = cur_bsz
# 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
# to know which original audio is being decoded
new_cur_to_prev_index_map = []
for i in range(prev_bsz):
prev_i = cur_to_prev_index_map[i]
if seek[prev_i] >= max_frames[prev_i]:
cut_index = i + (cur_bsz - prev_bsz)
cur_bsz -= 1
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
else:
# cut out index that goes away
new_cur_to_prev_index_map.append(prev_i)
# 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
cur_to_prev_index_map = new_cur_to_prev_index_map
time_offset = seek * time_precision / input_stride
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
# 6.5 Make sure that all inputs are padded to the same input length
segment_input = []
for i in range(cur_bsz):
prev_i = cur_to_prev_index_map[i]
segment_input_slice = input_features[
i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]
]
if segment_input_slice.shape[-1] < num_segment_frames:
# pad to 3000 if necessary
segment_input_slice = F.pad(
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
)
segment_input.append(segment_input_slice)
segment_input = torch.cat(segment_input, dim=0)
# 6.6 Batch generate current chunk
seek_outputs = super().generate(
segment_input,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
**kwargs,
)
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
num_frames = getattr(generation_config, "num_frames", None)
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
)
if return_dict_in_generate:
seek_sequences = seek_outputs["sequences"]
seek_outputs = [
{k: v[i] for k, v in seek_outputs.items()}
for i in range(next(iter(seek_outputs.values())).size(0))
]
else:
seek_sequences = seek_outputs
# 6.7 Loop over each decoded audio individually as each decoding can be of a different length
for i, seek_sequence in enumerate(seek_sequences):
prev_i = cur_to_prev_index_map[i]
# make sure we cut a predicted EOS token if we are not finished with the generation yet
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id:
seek_sequence = seek_sequence[:-1]
# remove all padding tokens
if seek_sequence[-1] == self.generation_config.pad_token_id:
num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum()
seek_sequence = seek_sequence[:-num_paddings]
segments, segment_offset = self._retrieve_segment(
seek_sequence=seek_sequence,
seek_outputs=seek_outputs,
time_offset=time_offset,
timestamp_begin=timestamp_begin,
seek_num_frames=seek_num_frames,
cur_bsz=cur_bsz,
time_precision=time_precision,
input_stride=input_stride,
prev_idx=prev_i,
idx=i,
)
current_segments[prev_i] += segments
seek[prev_i] += segment_offset
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
sequences = []
max_total_length = 0
for current_segment_list in current_segments:
sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1))
max_total_length = max(max_total_length, len(sequences[-1]))
for i in range(batch_size):
sequences[i] = F.pad(
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
)
sequences = torch.stack(sequences, dim=0)
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}
return sequences
@staticmethod
def _retrieve_segment(
    seek_sequence,
    seek_outputs,
    time_offset,
    timestamp_begin,
    seek_num_frames,
    cur_bsz,
    time_precision,
    input_stride,
    prev_idx,
    idx,
):
    """
    Split one decoded chunk into timestamp-delimited segments and compute how far to advance `seek`.

    Args:
        seek_sequence: Token ids generated for a single audio of the current chunk. It is compared
            element-wise against `timestamp_begin` and sliced along its only dimension.
        seek_outputs: Per-sample generation results; `seek_outputs[idx]` is attached to every segment.
        time_offset: Indexable by `prev_idx`; the time already covered by earlier chunks of that audio.
        timestamp_begin: Id of the first timestamp token; every id `>= timestamp_begin` is a timestamp.
        seek_num_frames: Indexable by `prev_idx`; number of frames of the current chunk for that audio.
        cur_bsz: Unused. Kept in the signature so existing callers keep working.
        time_precision: Multiplier that converts a timestamp position into a time value.
        input_stride: Multiplier that converts a timestamp position into a frame offset.
        prev_idx: Index of this audio in the original (un-shrunk) batch.
        idx: Index of this audio in the current batch.

    Returns:
        `(segments, segment_offset)`: a list of dicts with keys `"start"`, `"end"`, `"tokens"` and
        `"result"`, and the number of frames by which the caller should advance its seek position.
    """
    # An "end of segment" is predicted whenever Whisper emits a timestamp token.
    timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
    # `seek_sequence` is a single sequence, so the tail is a flat list of two booleans. Comparing it
    # to a nested `cur_bsz * [[False, True]]` could never be equal, which silently disabled the
    # "single trailing timestamp" handling below.
    single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
    # Positions `i` where both token `i` and token `i + 1` are timestamps (segment boundary pairs).
    timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]

    if len(timestamp_segment_indices) > 0:
        # The output contains consecutive timestamp tokens: cut one segment per boundary.
        slices = timestamp_segment_indices.tolist()
        if single_timestamp_ending:
            # The text after the last boundary is closed by a lone timestamp, so it forms a
            # complete final segment as well.
            slices.append(len(seek_sequence))

        segments = []
        last_slice = 0
        for current_slice in slices:
            # Each slice starts right after the previous boundary and includes the closing timestamp.
            sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1]
            start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
            end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
            segments.append(
                {
                    "start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
                    "end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
                    "tokens": sliced_tokens,
                    "result": seek_outputs[idx],
                }
            )
            last_slice = current_slice

        if single_timestamp_ending:
            # A single timestamp at the end means no speech after it: consume the whole chunk.
            segment_offset = seek_num_frames[prev_idx]
        else:
            # Otherwise drop everything after the last predicted boundary (the chunk was cut in the
            # middle of speech) and seek to that last timestamp.
            last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin
            segment_offset = last_timestamp_pos * input_stride
    else:
        # No boundary pair was predicted: the whole decoding is one segment.
        timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
        # NOTE(review): this fallback is a frame count that is later multiplied by `time_precision`;
        # confirm the intended unit of the resulting "end" value.
        last_timestamp_pos = seek_num_frames[prev_idx]
        if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
            # No consecutive timestamps, but there is a non-zero timestamp: use the last one.
            last_timestamp_pos = timestamps[-1].item() - timestamp_begin

        segments = [
            {
                "start": time_offset[prev_idx],
                "end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
                "tokens": seek_sequence,
                "result": seek_outputs[idx],
            }
        ]
        segment_offset = seek_num_frames[prev_idx]

    return segments, segment_offset
def prepare_inputs_for_generation(
self,
decoder_input_ids,
@@ -2229,7 +2598,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
>>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
>>> # decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
```"""

View File

@@ -507,10 +507,20 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
inputs, self.feature_extractor, chunk_len, stride_left, stride_right, rescale, self.torch_dtype
):
yield item
else:
if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
processed = self.feature_extractor(
inputs,
sampling_rate=self.feature_extractor.sampling_rate,
truncation=False,
padding="longest",
return_tensors="pt",
)
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
if stride is not None:
@@ -551,8 +561,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if stride is not None:
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
else:
generate_kwargs["encoder_outputs"] = encoder(inputs, attention_mask=attention_mask)
tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
**generate_kwargs,
)

File diff suppressed because one or more lines are too long

View File

@@ -16,7 +16,7 @@ import unittest
import numpy as np
import pytest
from datasets import load_dataset
from datasets import Audio, load_dataset
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import (
@@ -329,16 +329,16 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{"text": " Conquered", "timestamp": (0.5, 1.2)},
{"text": " returned", "timestamp": (1.2, 1.64)},
{"text": " to", "timestamp": (1.64, 1.84)},
{"text": " its", "timestamp": (1.84, 2.02)},
{"text": " place", "timestamp": (2.02, 2.28)},
{"text": " amidst", "timestamp": (2.28, 2.78)},
{"text": " the", "timestamp": (2.78, 2.96)},
{"text": " tents.", "timestamp": (2.96, 3.48)},
'text': ' Conquered returned to its place amidst the tents.',
'chunks': [
{'text': ' Conquered', 'timestamp': (0.5, 1.2)},
{'text': ' returned', 'timestamp': (1.2, 1.64)},
{'text': ' to', 'timestamp': (1.64, 1.84)},
{'text': ' its', 'timestamp': (1.84, 2.02)},
{'text': ' place', 'timestamp': (2.02, 2.28)},
{'text': ' amidst', 'timestamp': (2.28, 2.8)},
{'text': ' the', 'timestamp': (2.8, 2.98)},
{'text': ' tents.', 'timestamp': (2.98, 3.48)},
],
},
)
@@ -776,27 +776,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
'text': ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
'chunks': [
{'text': ' Mr.', 'timestamp': (0.38, 1.04)},
{'text': ' Quilter', 'timestamp': (1.04, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' of', 'timestamp': (1.98, 2.32)},
{'text': ' the', 'timestamp': (2.32, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' classes,', 'timestamp': (2.56, 3.4)},
{'text': ' and', 'timestamp': (3.4, 3.54)},
{'text': ' we', 'timestamp': (3.54, 3.62)},
{'text': ' are', 'timestamp': (3.62, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
{'text': ' welcome', 'timestamp': (4.26, 4.56)},
{'text': ' his', 'timestamp': (4.56, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 5.84)}
]
}
)
# fmt: on
@@ -1087,6 +1087,34 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_whisper_longform(self):
# Integration test: run one sample of the "distil-whisper/meanwhile" test split through the
# automatic-speech-recognition pipeline with `openai/whisper-tiny.en` and compare the produced
# text against a fixed reference transcription.
# NOTE(review): the model is moved to "cuda" and the pipeline is created with device="cuda:0",
# so this needs a CUDA device although only `require_torch` is declared here - confirm the decorator.
# fmt: off
EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
# fmt: on
# Load the processor and model, then place the model on the GPU.
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model = model.to("cuda")
# Build the pipeline from the already-loaded model plus the processor's tokenizer and feature extractor.
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
device="cuda:0",
)
# Load the dataset and resample the audio column to 16 kHz.
ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
# Slicing with `[:1]` yields a one-element batch, hence the `[0]` on the pipeline output.
audio = ds[:1]["audio"]
result = pipe(audio)[0]["text"]
assert result == EXPECTED_RESULT
@require_torch
@slow
def test_chunking_and_timestamps(self):
@@ -1355,7 +1383,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
out,
{
"chunks": [
{"text": "", "timestamp": (18.94, 0.0)},
{"text": "", "timestamp": (18.94, 0.02)},
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
],
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",