[Whisper] Finalize batched SOTA long-form generation (#27658)
* finalize * make fix copies whisper * [Tests] Make sure that we don't run tests mulitple times * Update src/transformers/models/whisper/modeling_whisper.py * [Tests] Make sure that we don't run tests mulitple times * fix more * improve * improve * improve further * improve more * improve * fix more * git commit and git push * fix more * fix more * fix more * New try * Fix more whisper stuff * Improve * correct more * correct more * correct more * Fix some tests * Add more tests * correct more * correct more * correct more * push * correct more * Fix more * Better * without dec mask * correct more * clean * save intermediate * Fix more * Fix VAD for large-v2 * Save new * Correct more * make cleaner * correct tests * correct src * Finish * Fix more * Fix more * finish * Fix edge cases * fix return_dict_in_generate * fix all tests * make style * add docstrings * add docstrings * Fix logit processor * make style * fix pipeline test * fix more style * Apply suggestions from code review * apply feedback Sanchit * correct more * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct more * correct more * correct more * Fix staticmethod * correct more * fix * fix slow tests * make style * fix tokenizer test * fix tokenizer test * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * finish * finish * revert kwargs change --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
d4fc1eb498
commit
690fe73f20
@@ -95,6 +95,7 @@ class LogitsProcessorList(list):
|
|||||||
scores = processor(input_ids, scores, **kwargs)
|
scores = processor(input_ids, scores, **kwargs)
|
||||||
else:
|
else:
|
||||||
scores = processor(input_ids, scores)
|
scores = processor(input_ids, scores)
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
@@ -1657,6 +1658,9 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
|
|||||||
self.begin_suppress_tokens = list(begin_suppress_tokens)
|
self.begin_suppress_tokens = list(begin_suppress_tokens)
|
||||||
self.begin_index = begin_index
|
self.begin_index = begin_index
|
||||||
|
|
||||||
|
def set_begin_index(self, begin_index):
|
||||||
|
self.begin_index = begin_index
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
if input_ids.shape[1] == self.begin_index:
|
if input_ids.shape[1] == self.begin_index:
|
||||||
@@ -1778,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
|
||||||
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
|
||||||
predicting timestamps that are too far in the future.
|
predicting timestamps that are too far in the future.
|
||||||
|
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
|
||||||
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
|
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
@@ -1810,11 +1815,11 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
|
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
|
||||||
): # support for the kwargs
|
): # support for the kwargs
|
||||||
self.eos_token_id = generate_config.eos_token_id
|
|
||||||
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
|
||||||
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
|
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
|
||||||
|
self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id
|
||||||
|
|
||||||
# this variable is mostly just used for testing
|
# this variable is mostly just used for testing
|
||||||
self._detect_timestamp_from_logprob = (
|
self._detect_timestamp_from_logprob = (
|
||||||
@@ -1823,10 +1828,17 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
|
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.begin_index = (
|
num_forced_ids = (
|
||||||
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
|
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
|
||||||
)
|
)
|
||||||
|
self.begin_index = begin_index or (num_forced_ids + 1)
|
||||||
|
|
||||||
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
|
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
|
||||||
|
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
|
||||||
|
# self.max_initial_timestamp_index = 50
|
||||||
|
|
||||||
|
def set_begin_index(self, begin_index):
|
||||||
|
self.begin_index = begin_index
|
||||||
|
|
||||||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
@@ -1878,6 +1890,60 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
|
|||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
class WhisperNoSpeechDetection(LogitsProcessor):
|
||||||
|
r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""
|
||||||
|
|
||||||
|
def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
|
||||||
|
self.no_speech_token = no_speech_token
|
||||||
|
# offset between <start-of-transcription> token, <SOT>, in paper and first generated token
|
||||||
|
# is equal to the position of the first generated token index
|
||||||
|
self.start_of_trans_offset = begin_index
|
||||||
|
|
||||||
|
# `self.begin_index` is a running value that is changed on the fly
|
||||||
|
self.begin_index = begin_index
|
||||||
|
self._no_speech_prob = [0.0]
|
||||||
|
self.is_scores_logprobs = scores_is_logprobs
|
||||||
|
|
||||||
|
# overwritten dynamically
|
||||||
|
self.model = None
|
||||||
|
self.inputs = None
|
||||||
|
|
||||||
|
def set_model(self, model):
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
def set_inputs(self, inputs):
|
||||||
|
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
|
||||||
|
self.inputs["input_features"] = self.inputs.pop("inputs")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def no_speech_prob(self):
|
||||||
|
return self._no_speech_prob
|
||||||
|
|
||||||
|
def set_begin_index(self, begin_index):
|
||||||
|
self.begin_index = begin_index
|
||||||
|
|
||||||
|
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if input_ids.shape[1] == self.begin_index:
|
||||||
|
if self.start_of_trans_offset > 1:
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = self.model(**self.inputs).logits
|
||||||
|
|
||||||
|
no_speech_index = self.begin_index - self.start_of_trans_offset
|
||||||
|
no_speech_scores = logits[:, no_speech_index]
|
||||||
|
else:
|
||||||
|
no_speech_scores = scores
|
||||||
|
|
||||||
|
if self.is_scores_logprobs:
|
||||||
|
probs = no_speech_scores.exp()
|
||||||
|
else:
|
||||||
|
probs = no_speech_scores.float().softmax(dim=-1)
|
||||||
|
|
||||||
|
self._no_speech_prob = probs[:, self.no_speech_token]
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
|
||||||
r"""
|
r"""
|
||||||
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
|
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
|
||||||
|
|||||||
@@ -518,6 +518,8 @@ class GenerationMixin:
|
|||||||
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
|
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
|
||||||
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
|
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
|
||||||
pass
|
pass
|
||||||
|
elif self.config.model_type in ["whisper"]:
|
||||||
|
pass
|
||||||
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
|
||||||
# decoder_attention_mask if provided)
|
# decoder_attention_mask if provided)
|
||||||
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
|
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
|
||||||
|
|||||||
1493
src/transformers/models/whisper/generation_whisper.py
Normal file
1493
src/transformers/models/whisper/generation_whisper.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -13,10 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" PyTorch Whisper model."""
|
""" PyTorch Whisper model."""
|
||||||
|
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
import warnings
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -27,7 +24,6 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...generation.logits_process import WhisperTimeStampLogitsProcessor
|
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
@@ -47,7 +43,7 @@ from ...utils import (
|
|||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from .configuration_whisper import WhisperConfig
|
from .configuration_whisper import WhisperConfig
|
||||||
from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE
|
from .generation_whisper import WhisperGenerationMixin
|
||||||
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
@@ -231,87 +227,15 @@ def _compute_mask_indices(
|
|||||||
return spec_aug_mask
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Applies a median filter of width `filter_width` along the last dimension of the input.
|
|
||||||
|
|
||||||
The `inputs` tensor is assumed to be 3- or 4-dimensional.
|
|
||||||
"""
|
|
||||||
if filter_width <= 0 or filter_width % 2 != 1:
|
|
||||||
raise ValueError("`filter_width` should be an odd number")
|
|
||||||
|
|
||||||
pad_width = filter_width // 2
|
|
||||||
if inputs.shape[-1] <= pad_width:
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
# Pad the left and right edges.
|
|
||||||
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
|
|
||||||
|
|
||||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
|
||||||
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _dynamic_time_warping(matrix: np.ndarray):
|
|
||||||
"""
|
|
||||||
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
|
|
||||||
token-level timestamps.
|
|
||||||
"""
|
|
||||||
output_length, input_length = matrix.shape
|
|
||||||
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
|
|
||||||
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
|
|
||||||
|
|
||||||
cost[0, 0] = 0
|
|
||||||
for j in range(1, input_length + 1):
|
|
||||||
for i in range(1, output_length + 1):
|
|
||||||
c0 = cost[i - 1, j - 1]
|
|
||||||
c1 = cost[i - 1, j]
|
|
||||||
c2 = cost[i, j - 1]
|
|
||||||
|
|
||||||
if c0 < c1 and c0 < c2:
|
|
||||||
c, t = c0, 0
|
|
||||||
elif c1 < c0 and c1 < c2:
|
|
||||||
c, t = c1, 1
|
|
||||||
else:
|
|
||||||
c, t = c2, 2
|
|
||||||
|
|
||||||
cost[i, j] = matrix[i - 1, j - 1] + c
|
|
||||||
trace[i, j] = t
|
|
||||||
|
|
||||||
# backtrace
|
|
||||||
i = trace.shape[0] - 1
|
|
||||||
j = trace.shape[1] - 1
|
|
||||||
trace[0, :] = 2
|
|
||||||
trace[:, 0] = 1
|
|
||||||
|
|
||||||
text_indices = []
|
|
||||||
time_indices = []
|
|
||||||
while i > 0 or j > 0:
|
|
||||||
text_indices.append(i - 1)
|
|
||||||
time_indices.append(j - 1)
|
|
||||||
if trace[i, j] == 0:
|
|
||||||
i -= 1
|
|
||||||
j -= 1
|
|
||||||
elif trace[i, j] == 1:
|
|
||||||
i -= 1
|
|
||||||
elif trace[i, j] == 2:
|
|
||||||
j -= 1
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
|
|
||||||
)
|
|
||||||
|
|
||||||
text_indices = np.array(text_indices)[::-1]
|
|
||||||
time_indices = np.array(time_indices)[::-1]
|
|
||||||
return text_indices, time_indices
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperPositionalEmbedding(nn.Embedding):
|
class WhisperPositionalEmbedding(nn.Embedding):
|
||||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
super().__init__(num_positions, embedding_dim)
|
super().__init__(num_positions, embedding_dim)
|
||||||
|
|
||||||
def forward(self, input_ids, past_key_values_length=0):
|
def forward(self, input_ids, past_key_values_length=0, position_ids=None):
|
||||||
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
|
if position_ids is None:
|
||||||
|
return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
|
||||||
|
else:
|
||||||
|
return self.weight[position_ids]
|
||||||
|
|
||||||
|
|
||||||
class WhisperAttention(nn.Module):
|
class WhisperAttention(nn.Module):
|
||||||
@@ -1358,6 +1282,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
cross_attn_head_mask=None,
|
cross_attn_head_mask=None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
|
position_ids=None,
|
||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
@@ -1461,9 +1386,13 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
|
positions = self.embed_positions(
|
||||||
|
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
|
positions = self.embed_positions(
|
||||||
|
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds + positions
|
hidden_states = inputs_embeds + positions
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
@@ -1645,6 +1574,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
@@ -1703,6 +1633,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
cross_attn_head_mask=cross_attn_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
inputs_embeds=decoder_inputs_embeds,
|
inputs_embeds=decoder_inputs_embeds,
|
||||||
|
position_ids=decoder_position_ids,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1728,7 +1659,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
|
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
|
||||||
WHISPER_START_DOCSTRING,
|
WHISPER_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
_tied_weights_keys = ["proj_out.weight"]
|
_tied_weights_keys = ["proj_out.weight"]
|
||||||
|
|
||||||
@@ -1776,6 +1707,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
|
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
@@ -1830,6 +1762,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
cross_attn_head_mask=cross_attn_head_mask,
|
cross_attn_head_mask=cross_attn_head_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
decoder_inputs_embeds=decoder_inputs_embeds,
|
decoder_inputs_embeds=decoder_inputs_embeds,
|
||||||
|
decoder_position_ids=decoder_position_ids,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
@@ -1860,647 +1793,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
encoder_attentions=outputs.encoder_attentions,
|
encoder_attentions=outputs.encoder_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
input_features: Optional[torch.Tensor] = None,
|
|
||||||
generation_config=None,
|
|
||||||
logits_processor=None,
|
|
||||||
stopping_criteria=None,
|
|
||||||
prefix_allowed_tokens_fn=None,
|
|
||||||
synced_gpus=False,
|
|
||||||
return_timestamps=None,
|
|
||||||
task=None,
|
|
||||||
language=None,
|
|
||||||
is_multilingual=None,
|
|
||||||
prompt_ids: Optional[torch.Tensor] = None,
|
|
||||||
num_segment_frames: Optional[int] = None,
|
|
||||||
return_token_timestamps: Optional[bool] = None,
|
|
||||||
return_segments: bool = False,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
time_precision: int = 0.02,
|
|
||||||
return_dict_in_generate: Optional[bool] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Transcribes or translates passed mel input features to a sequence of token ids.
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
|
|
||||||
model's default generation configuration. You can override any `generation_config` by passing the corresponding
|
|
||||||
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
|
|
||||||
|
|
||||||
For an overview of generation strategies and code examples, check out the [following
|
|
||||||
guide](./generation_strategies).
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
|
|
||||||
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
|
|
||||||
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
|
|
||||||
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
|
|
||||||
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
|
|
||||||
generation_config (`~generation.GenerationConfig`, *optional*):
|
|
||||||
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
|
|
||||||
passed to generate matching the attributes of `generation_config` will override them. If
|
|
||||||
`generation_config` is not provided, the default will be used, which had the following loading
|
|
||||||
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
|
|
||||||
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
|
|
||||||
default values, whose documentation should be checked to parameterize generation.
|
|
||||||
logits_processor (`LogitsProcessorList`, *optional*):
|
|
||||||
Custom logits processors that complement the default logits processors built from arguments and
|
|
||||||
generation config. If a logit processor is passed that is already created with the arguments or a
|
|
||||||
generation config an error is thrown. This feature is intended for advanced users.
|
|
||||||
stopping_criteria (`StoppingCriteriaList`, *optional*):
|
|
||||||
Custom stopping criteria that complement the default stopping criteria built from arguments and a
|
|
||||||
generation config. If a stopping criteria is passed that is already created with the arguments or a
|
|
||||||
generation config an error is thrown. This feature is intended for advanced users.
|
|
||||||
prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
|
|
||||||
If provided, this function constraints the beam search to allowed tokens only at each step. If not
|
|
||||||
provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
|
|
||||||
`input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
|
|
||||||
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
|
|
||||||
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
|
|
||||||
Retrieval](https://arxiv.org/abs/2010.00904).
|
|
||||||
synced_gpus (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
|
|
||||||
return_timestamps (`bool`, *optional*):
|
|
||||||
Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`.
|
|
||||||
task (`str`, *optional*):
|
|
||||||
Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids`
|
|
||||||
will be updated accordingly.
|
|
||||||
language (`str`, *optional*):
|
|
||||||
Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can
|
|
||||||
find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary.
|
|
||||||
is_multilingual (`bool`, *optional*):
|
|
||||||
Whether or not the model is multilingual.
|
|
||||||
prompt_ids (`torch.Tensor`, *optional*):
|
|
||||||
Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is
|
|
||||||
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
|
|
||||||
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
|
|
||||||
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
|
|
||||||
return_token_timestamps (`bool`, *optional*):
|
|
||||||
Whether to return token-level timestamps with the text. This can be used with or without the
|
|
||||||
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
|
||||||
words.
|
|
||||||
return_segments (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to additionally return a list of all segments. Note that this option can only be enabled
|
|
||||||
when doing long-form transcription.
|
|
||||||
attention_mask (`torch.Tensor`, *optional*):
|
|
||||||
`attention_mask` needs to be passed when doing long-form transcription using a batch size > 1.
|
|
||||||
time_precision (`int`, *optional*, defaults to 0.02):
|
|
||||||
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
|
|
||||||
for 20 ms.
|
|
||||||
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens.
|
|
||||||
Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when
|
|
||||||
`return_segments` is set True. In this case the generation outputs of each segment is added to each
|
|
||||||
segment.
|
|
||||||
kwargs (`Dict[str, Any]`, *optional*):
|
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
|
||||||
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
[`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
|
|
||||||
or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`.
|
|
||||||
|
|
||||||
If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned.
|
|
||||||
|
|
||||||
else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are:
|
|
||||||
|
|
||||||
- [`~generation.GenerateEncoderDecoderOutput`],
|
|
||||||
- [`~generation.GenerateBeamEncoderDecoderOutput`]
|
|
||||||
|
|
||||||
else only the generated output sequence ids are returned.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
- *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
|
||||||
>>> from datasets import load_dataset, Audio
|
|
||||||
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
|
||||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
|
||||||
>>> model.cuda()
|
|
||||||
|
|
||||||
>>> # load audios > 30 seconds
|
|
||||||
>>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
|
|
||||||
>>> # resample to 16kHz
|
|
||||||
>>> ds = ds.cast_column("audio", Audio(sampling_rate=16000))
|
|
||||||
>>> # take first 8 audios and retrieve array
|
|
||||||
>>> audio = ds[:8]["audio"]
|
|
||||||
>>> audio = [x["array"] for x in audio]
|
|
||||||
|
|
||||||
>>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio
|
|
||||||
>>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
|
|
||||||
>>> inputs = inputs.to("cuda", torch.float32)
|
|
||||||
|
|
||||||
>>> # transcribe audio to ids
|
|
||||||
>>> generated_ids = model.generate(**inputs)
|
|
||||||
|
|
||||||
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
|
||||||
>>> transcription[0]
|
|
||||||
' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!'
|
|
||||||
```
|
|
||||||
|
|
||||||
- *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate.
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
|
||||||
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
|
||||||
|
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
|
|
||||||
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
|
|
||||||
>>> input_features = inputs.input_features
|
|
||||||
|
|
||||||
>>> generated_ids = model.generate(inputs=input_features)
|
|
||||||
|
|
||||||
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
|
||||||
>>> transcription
|
|
||||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
|
||||||
```
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
if "inputs" in kwargs:
|
|
||||||
input_features = kwargs.pop("inputs")
|
|
||||||
warnings.warn(
|
|
||||||
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
|
|
||||||
FutureWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
if generation_config is None:
|
|
||||||
generation_config = copy.deepcopy(self.generation_config)
|
|
||||||
|
|
||||||
return_dict_in_generate = (
|
|
||||||
return_dict_in_generate
|
|
||||||
if return_dict_in_generate is not None
|
|
||||||
else generation_config.return_dict_in_generate
|
|
||||||
)
|
|
||||||
|
|
||||||
input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
|
|
||||||
if num_segment_frames is None:
|
|
||||||
num_segment_frames = input_stride * self.config.max_source_positions
|
|
||||||
|
|
||||||
# 1. Check whether we're in shortform or longform mode
|
|
||||||
if input_features is not None:
|
|
||||||
total_input_frames = input_features.shape[-1]
|
|
||||||
elif "encoder_outputs" in kwargs:
|
|
||||||
encoder_outputs_shape = (
|
|
||||||
kwargs["encoder_outputs"][0].shape
|
|
||||||
if isinstance(kwargs["encoder_outputs"], BaseModelOutput)
|
|
||||||
else kwargs["encoder_outputs"].shape
|
|
||||||
)
|
|
||||||
total_input_frames = encoder_outputs_shape[1] * input_stride
|
|
||||||
else:
|
|
||||||
raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.")
|
|
||||||
|
|
||||||
is_shortform = total_input_frames <= num_segment_frames
|
|
||||||
|
|
||||||
# 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not
|
|
||||||
if return_timestamps is True:
|
|
||||||
if not hasattr(generation_config, "no_timestamps_token_id"):
|
|
||||||
raise ValueError(
|
|
||||||
"You are trying to return timestamps, but the generation config is not properly set. "
|
|
||||||
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
|
|
||||||
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
|
||||||
)
|
|
||||||
generation_config.return_timestamps = return_timestamps
|
|
||||||
elif not is_shortform:
|
|
||||||
if return_timestamps is False:
|
|
||||||
raise ValueError(
|
|
||||||
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
|
|
||||||
"requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not hasattr(generation_config, "no_timestamps_token_id"):
|
|
||||||
raise ValueError(
|
|
||||||
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
|
|
||||||
"requires the generation config to have `no_timestamps_token_id` correctly. "
|
|
||||||
"Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. "
|
|
||||||
"For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
|
|
||||||
"or make sure to pass no more than 3000 mel input features."
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Setting `return_timestamps=True` for long-form generation.")
|
|
||||||
generation_config.return_timestamps = True
|
|
||||||
else:
|
|
||||||
generation_config.return_timestamps = False
|
|
||||||
|
|
||||||
# 3. Make sure to correctly set language-related parameters
|
|
||||||
if is_multilingual is not None:
|
|
||||||
if not hasattr(generation_config, "is_multilingual"):
|
|
||||||
raise ValueError(
|
|
||||||
"The generation config is outdated and is thus not compatible with the `is_multilingual` argument "
|
|
||||||
"to `generate`. Please update the generation config as per the instructions "
|
|
||||||
"https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
|
||||||
)
|
|
||||||
generation_config.is_multilingual = is_multilingual
|
|
||||||
|
|
||||||
if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual:
|
|
||||||
if task is not None or language is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot specify `task` or `language` for an English-only model. If the model is intended to be "
|
|
||||||
"multilingual, pass `is_multilingual=True` to generate, or update the generation config."
|
|
||||||
)
|
|
||||||
|
|
||||||
if language is not None:
|
|
||||||
if not hasattr(generation_config, "lang_to_id"):
|
|
||||||
raise ValueError(
|
|
||||||
"The generation config is outdated and is thus not compatible with the `language` argument "
|
|
||||||
"to `generate`. Either set the language using the `forced_decoder_ids` in the model config, "
|
|
||||||
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
|
||||||
)
|
|
||||||
language = language.lower()
|
|
||||||
generation_config.language = language
|
|
||||||
if task is not None:
|
|
||||||
if not hasattr(generation_config, "task_to_id"):
|
|
||||||
raise ValueError(
|
|
||||||
"The generation config is outdated and is thus not compatible with the `task` argument "
|
|
||||||
"to `generate`. Either set the task using the `forced_decoder_ids` in the model config, "
|
|
||||||
"or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224"
|
|
||||||
)
|
|
||||||
generation_config.task = task
|
|
||||||
|
|
||||||
# 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps`
|
|
||||||
forced_decoder_ids = None
|
|
||||||
# Legacy code for backward compatibility
|
|
||||||
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
|
|
||||||
forced_decoder_ids = self.config.forced_decoder_ids
|
|
||||||
elif (
|
|
||||||
hasattr(self.generation_config, "forced_decoder_ids")
|
|
||||||
and self.generation_config.forced_decoder_ids is not None
|
|
||||||
):
|
|
||||||
forced_decoder_ids = self.generation_config.forced_decoder_ids
|
|
||||||
else:
|
|
||||||
forced_decoder_ids = kwargs.get("forced_decoder_ids", None)
|
|
||||||
|
|
||||||
if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None):
|
|
||||||
forced_decoder_ids = []
|
|
||||||
if hasattr(generation_config, "language"):
|
|
||||||
if generation_config.language in generation_config.lang_to_id.keys():
|
|
||||||
language_token = generation_config.language
|
|
||||||
elif generation_config.language in TO_LANGUAGE_CODE.keys():
|
|
||||||
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
|
|
||||||
elif generation_config.language in TO_LANGUAGE_CODE.values():
|
|
||||||
language_token = f"<|{generation_config.language}|>"
|
|
||||||
else:
|
|
||||||
is_language_code = len(generation_config.language) == 2
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported language: {generation_config.language}. Language should be one of:"
|
|
||||||
f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
|
|
||||||
)
|
|
||||||
if language_token not in generation_config.lang_to_id:
|
|
||||||
raise ValueError(
|
|
||||||
f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`."
|
|
||||||
"(You should just add it to the generation config)"
|
|
||||||
)
|
|
||||||
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
|
|
||||||
else:
|
|
||||||
forced_decoder_ids.append((1, None)) # automatically detect the language
|
|
||||||
|
|
||||||
if hasattr(generation_config, "task"):
|
|
||||||
if generation_config.task in TASK_IDS:
|
|
||||||
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
|
|
||||||
)
|
|
||||||
elif hasattr(generation_config, "task_to_id"):
|
|
||||||
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe
|
|
||||||
if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
|
|
||||||
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
|
||||||
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
|
||||||
|
|
||||||
if forced_decoder_ids is not None:
|
|
||||||
generation_config.forced_decoder_ids = forced_decoder_ids
|
|
||||||
|
|
||||||
if prompt_ids is not None:
|
|
||||||
if kwargs.get("decoder_start_token_id") is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten."
|
|
||||||
)
|
|
||||||
prompt_ids = prompt_ids.tolist()
|
|
||||||
decoder_start_token_id, *text_prompt_ids = prompt_ids
|
|
||||||
# Slicing the text prompt ids in a manner consistent with the OpenAI implementation
|
|
||||||
# to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599)
|
|
||||||
text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :]
|
|
||||||
# Set the decoder_start_token_id to <|startofprev|>
|
|
||||||
kwargs.update({"decoder_start_token_id": decoder_start_token_id})
|
|
||||||
|
|
||||||
# If the user passes `max_new_tokens`, increase its number to account for the prompt
|
|
||||||
if kwargs.get("max_new_tokens", None) is not None:
|
|
||||||
kwargs["max_new_tokens"] += len(text_prompt_ids)
|
|
||||||
if kwargs["max_new_tokens"] >= self.config.max_target_positions:
|
|
||||||
raise ValueError(
|
|
||||||
f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` "
|
|
||||||
f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced "
|
|
||||||
f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the "
|
|
||||||
f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. "
|
|
||||||
"You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, "
|
|
||||||
f"so that their combined length is less that {self.config.max_target_positions}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reformat the forced_decoder_ids to incorporate the prompt
|
|
||||||
non_prompt_forced_decoder_ids = (
|
|
||||||
kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids
|
|
||||||
)
|
|
||||||
forced_decoder_ids = [
|
|
||||||
*text_prompt_ids,
|
|
||||||
generation_config.decoder_start_token_id,
|
|
||||||
*[token for _rank, token in non_prompt_forced_decoder_ids],
|
|
||||||
]
|
|
||||||
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)]
|
|
||||||
generation_config.forced_decoder_ids = forced_decoder_ids
|
|
||||||
|
|
||||||
if return_token_timestamps:
|
|
||||||
kwargs["output_attentions"] = True
|
|
||||||
return_dict_in_generate = True
|
|
||||||
kwargs["output_scores"] = True
|
|
||||||
|
|
||||||
if getattr(generation_config, "task", None) == "translate":
|
|
||||||
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
|
||||||
if not hasattr(generation_config, "alignment_heads"):
|
|
||||||
raise ValueError(
|
|
||||||
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
|
|
||||||
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
|
|
||||||
)
|
|
||||||
|
|
||||||
if kwargs.get("num_frames") is not None:
|
|
||||||
generation_config.num_frames = kwargs.pop("num_frames")
|
|
||||||
|
|
||||||
if generation_config.return_timestamps is True:
|
|
||||||
last_forced_decoder_ids = (
|
|
||||||
generation_config.forced_decoder_ids[-1][-1]
|
|
||||||
if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id:
|
|
||||||
# remove no_timestamp to be forcefully generated if we want to return timestamps
|
|
||||||
# this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly
|
|
||||||
forced_decoder_ids = generation_config.forced_decoder_ids[:-1]
|
|
||||||
# Make sure that if list is empty we set it to None
|
|
||||||
generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids
|
|
||||||
|
|
||||||
timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
|
||||||
logits_processor = (
|
|
||||||
timestamp_processor if logits_processor is None else timestamp_processor + logits_processor
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. If we're in shortform mode, simple generate the whole input at once and return the output
|
|
||||||
if is_shortform:
|
|
||||||
outputs = super().generate(
|
|
||||||
input_features,
|
|
||||||
generation_config,
|
|
||||||
logits_processor,
|
|
||||||
stopping_criteria,
|
|
||||||
prefix_allowed_tokens_fn,
|
|
||||||
synced_gpus,
|
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
|
||||||
num_frames = getattr(generation_config, "num_frames", None)
|
|
||||||
outputs["token_timestamps"] = self._extract_token_timestamps(
|
|
||||||
outputs, generation_config.alignment_heads, num_frames=num_frames
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
# 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated
|
|
||||||
# timestamp tokens
|
|
||||||
# 6.1 Set running parameters for while loop
|
|
||||||
if not return_segments and return_dict_in_generate:
|
|
||||||
raise ValueError(
|
|
||||||
"Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`"
|
|
||||||
)
|
|
||||||
|
|
||||||
# if input is longer than 30 seconds we default to long-form generation
|
|
||||||
timestamp_begin = self.generation_config.no_timestamps_token_id + 1
|
|
||||||
# input stride is mel frames per encoder output vector which is the product of all conv strides
|
|
||||||
batch_size = input_features.shape[0]
|
|
||||||
|
|
||||||
if batch_size > 1 and attention_mask is None:
|
|
||||||
raise ValueError(
|
|
||||||
"When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` "
|
|
||||||
)
|
|
||||||
elif batch_size > 1:
|
|
||||||
max_frames = attention_mask.sum(-1).cpu().to(torch.long)
|
|
||||||
seek = torch.zeros((batch_size,), dtype=torch.long)
|
|
||||||
else:
|
|
||||||
max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames
|
|
||||||
seek = torch.zeros((1,), dtype=torch.long)
|
|
||||||
|
|
||||||
current_segments = [[] for _ in range(batch_size)]
|
|
||||||
cur_to_prev_index_map = list(range(batch_size))
|
|
||||||
|
|
||||||
# batch size can decrease during the run
|
|
||||||
cur_bsz = prev_bsz = batch_size
|
|
||||||
|
|
||||||
# 6.2 Transcribe audio until we reach the end of all input audios
|
|
||||||
while (seek < max_frames).any():
|
|
||||||
prev_bsz = cur_bsz
|
|
||||||
|
|
||||||
# 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop
|
|
||||||
# in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order
|
|
||||||
# to know which original audio is being decoded
|
|
||||||
new_cur_to_prev_index_map = []
|
|
||||||
for i in range(prev_bsz):
|
|
||||||
prev_i = cur_to_prev_index_map[i]
|
|
||||||
if seek[prev_i] >= max_frames[prev_i]:
|
|
||||||
cut_index = i + (cur_bsz - prev_bsz)
|
|
||||||
cur_bsz -= 1
|
|
||||||
input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0)
|
|
||||||
else:
|
|
||||||
# cut out index that goes away
|
|
||||||
new_cur_to_prev_index_map.append(prev_i)
|
|
||||||
|
|
||||||
# 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk
|
|
||||||
cur_to_prev_index_map = new_cur_to_prev_index_map
|
|
||||||
time_offset = seek * time_precision / input_stride
|
|
||||||
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)
|
|
||||||
|
|
||||||
# 6.5 Make sure that all inputs are padded to the same input length
|
|
||||||
segment_input = []
|
|
||||||
for i in range(cur_bsz):
|
|
||||||
prev_i = cur_to_prev_index_map[i]
|
|
||||||
segment_input_slice = input_features[
|
|
||||||
i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]
|
|
||||||
]
|
|
||||||
|
|
||||||
if segment_input_slice.shape[-1] < num_segment_frames:
|
|
||||||
# pad to 3000 if necessary
|
|
||||||
segment_input_slice = F.pad(
|
|
||||||
segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1])
|
|
||||||
)
|
|
||||||
|
|
||||||
segment_input.append(segment_input_slice)
|
|
||||||
|
|
||||||
segment_input = torch.cat(segment_input, dim=0)
|
|
||||||
|
|
||||||
# 6.6 Batch generate current chunk
|
|
||||||
seek_outputs = super().generate(
|
|
||||||
segment_input,
|
|
||||||
generation_config,
|
|
||||||
logits_processor,
|
|
||||||
stopping_criteria,
|
|
||||||
prefix_allowed_tokens_fn,
|
|
||||||
synced_gpus,
|
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
|
||||||
num_frames = getattr(generation_config, "num_frames", None)
|
|
||||||
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
|
||||||
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
|
|
||||||
)
|
|
||||||
|
|
||||||
if return_dict_in_generate:
|
|
||||||
seek_sequences = seek_outputs["sequences"]
|
|
||||||
seek_outputs = [
|
|
||||||
{k: v[i] for k, v in seek_outputs.items()}
|
|
||||||
for i in range(next(iter(seek_outputs.values())).size(0))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
seek_sequences = seek_outputs
|
|
||||||
|
|
||||||
# 6.7 Loop over each decoded audio individually as each decoding can be of a different length
|
|
||||||
for i, seek_sequence in enumerate(seek_sequences):
|
|
||||||
prev_i = cur_to_prev_index_map[i]
|
|
||||||
|
|
||||||
# make sure we cut a predicted EOS token if we are not finished with the generation yet
|
|
||||||
is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i]
|
|
||||||
if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id:
|
|
||||||
seek_sequence = seek_sequence[:-1]
|
|
||||||
|
|
||||||
# remove all padding tokens
|
|
||||||
if seek_sequence[-1] == self.generation_config.pad_token_id:
|
|
||||||
num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum()
|
|
||||||
seek_sequence = seek_sequence[:-num_paddings]
|
|
||||||
|
|
||||||
segments, segment_offset = self._retrieve_segment(
|
|
||||||
seek_sequence=seek_sequence,
|
|
||||||
seek_outputs=seek_outputs,
|
|
||||||
time_offset=time_offset,
|
|
||||||
timestamp_begin=timestamp_begin,
|
|
||||||
seek_num_frames=seek_num_frames,
|
|
||||||
cur_bsz=cur_bsz,
|
|
||||||
time_precision=time_precision,
|
|
||||||
input_stride=input_stride,
|
|
||||||
prev_idx=prev_i,
|
|
||||||
idx=i,
|
|
||||||
)
|
|
||||||
|
|
||||||
current_segments[prev_i] += segments
|
|
||||||
seek[prev_i] += segment_offset
|
|
||||||
|
|
||||||
# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
|
|
||||||
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
|
|
||||||
sequences = []
|
|
||||||
max_total_length = 0
|
|
||||||
for current_segment_list in current_segments:
|
|
||||||
sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1))
|
|
||||||
max_total_length = max(max_total_length, len(sequences[-1]))
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
|
||||||
sequences[i] = F.pad(
|
|
||||||
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
sequences = torch.stack(sequences, dim=0)
|
|
||||||
|
|
||||||
# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
|
|
||||||
if return_segments:
|
|
||||||
return {"sequences": sequences, "segments": current_segments}
|
|
||||||
|
|
||||||
return sequences
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _retrieve_segment(
|
|
||||||
seek_sequence,
|
|
||||||
seek_outputs,
|
|
||||||
time_offset,
|
|
||||||
timestamp_begin,
|
|
||||||
seek_num_frames,
|
|
||||||
cur_bsz,
|
|
||||||
time_precision,
|
|
||||||
input_stride,
|
|
||||||
prev_idx,
|
|
||||||
idx,
|
|
||||||
):
|
|
||||||
# find the predicted "end of segment" predictions of Whisper
|
|
||||||
# "end of segment" predictions occur whenever Whisper predicts a timestamp token
|
|
||||||
timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin)
|
|
||||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == cur_bsz * [[False, True]]
|
|
||||||
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
|
||||||
|
|
||||||
# If whisper predicted a "end of segment" via a timestep token, let's go ever each
|
|
||||||
# "end of segment" prediction and slice the decoding into segments accordingly
|
|
||||||
if len(timestamp_segment_indices) > 0:
|
|
||||||
# if the output contains two consecutive timestamp tokens
|
|
||||||
slices = timestamp_segment_indices.tolist()
|
|
||||||
segments = []
|
|
||||||
if single_timestamp_ending:
|
|
||||||
slices.append(len(seek_sequence))
|
|
||||||
|
|
||||||
last_slice = 0
|
|
||||||
# Add each segment to list of all segments
|
|
||||||
for current_slice in slices:
|
|
||||||
sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1]
|
|
||||||
start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin
|
|
||||||
end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin
|
|
||||||
segments.append(
|
|
||||||
{
|
|
||||||
"start": time_offset[prev_idx] + start_timestamp_pos * time_precision,
|
|
||||||
"end": time_offset[prev_idx] + end_timestamp_pos * time_precision,
|
|
||||||
"tokens": sliced_tokens,
|
|
||||||
"result": seek_outputs[idx],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
last_slice = current_slice
|
|
||||||
|
|
||||||
if single_timestamp_ending:
|
|
||||||
# single timestamp at the end means no speech after the last timestamp.
|
|
||||||
segment_offset = seek_num_frames[prev_idx]
|
|
||||||
else:
|
|
||||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
|
||||||
# here we throw away all predictions after the last predicted "end of segment"
|
|
||||||
# since we are cutting right in the middle of an audio
|
|
||||||
last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin
|
|
||||||
segment_offset = last_timestamp_pos * input_stride
|
|
||||||
else:
|
|
||||||
# If whisper does not predict any "end of segment" token, then
|
|
||||||
# the whole decoding is considered a segment and we add it to the list of segments
|
|
||||||
timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()]
|
|
||||||
last_timestamp_pos = seek_num_frames[prev_idx]
|
|
||||||
if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin:
|
|
||||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
|
||||||
last_timestamp_pos = timestamps[-1].item() - timestamp_begin
|
|
||||||
|
|
||||||
segments = [
|
|
||||||
{
|
|
||||||
"start": time_offset[prev_idx],
|
|
||||||
"end": time_offset[prev_idx] + last_timestamp_pos * time_precision,
|
|
||||||
"tokens": seek_sequence,
|
|
||||||
"result": seek_outputs[idx],
|
|
||||||
}
|
|
||||||
]
|
|
||||||
segment_offset = seek_num_frames[prev_idx]
|
|
||||||
|
|
||||||
return segments, segment_offset
|
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
@@ -2508,8 +1800,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
decoder_attention_mask=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
decoder_position_ids = None
|
||||||
|
if decoder_attention_mask is not None:
|
||||||
|
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
|
||||||
|
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
past_length = past_key_values[0][0].shape[2]
|
||||||
|
|
||||||
@@ -2522,12 +1819,16 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
|
|
||||||
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
|
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
|
||||||
|
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"decoder_attention_mask": None,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"decoder_position_ids": decoder_position_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -2539,99 +1840,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
)
|
)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
|
|
||||||
"""
|
|
||||||
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
|
||||||
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
|
|
||||||
cross-attentions will be cropped before applying DTW.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tensor containing the timestamps in seconds for each predicted token
|
|
||||||
"""
|
|
||||||
# Create a list with `decoder_layers` elements, each a tensor of shape
|
|
||||||
# (batch size, attention_heads, output length, input length).
|
|
||||||
cross_attentions = []
|
|
||||||
for i in range(self.config.decoder_layers):
|
|
||||||
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
|
|
||||||
|
|
||||||
# Select specific cross-attention layers and heads. This is a tensor
|
|
||||||
# of shape (batch size, num selected, output length, input length).
|
|
||||||
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
|
||||||
weights = weights.permute([1, 0, 2, 3])
|
|
||||||
|
|
||||||
if "beam_indices" in generate_outputs:
|
|
||||||
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
|
|
||||||
# since the beam search strategy chooses the most probable sequences at the end of the search.
|
|
||||||
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
|
|
||||||
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
|
|
||||||
weights = weights[:, :, :weight_length]
|
|
||||||
|
|
||||||
# If beam index is still -1, it means that the associated token id is EOS
|
|
||||||
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
|
|
||||||
beam_indices = generate_outputs.beam_indices[:, :weight_length]
|
|
||||||
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
|
|
||||||
|
|
||||||
# Select the cross attention from the right beam for each output sequences
|
|
||||||
weights = torch.stack(
|
|
||||||
[
|
|
||||||
torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
|
|
||||||
for i in range(beam_indices.shape[1])
|
|
||||||
],
|
|
||||||
dim=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
|
|
||||||
batch_size = timestamps.shape[0]
|
|
||||||
|
|
||||||
if num_frames is not None:
|
|
||||||
# two cases:
|
|
||||||
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
|
|
||||||
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
|
|
||||||
|
|
||||||
# we're using np.unique because num_frames can be int/list/tuple
|
|
||||||
if len(np.unique(num_frames)) == 1:
|
|
||||||
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
|
|
||||||
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]
|
|
||||||
|
|
||||||
weights = weights[..., : num_frames // 2]
|
|
||||||
else:
|
|
||||||
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
|
||||||
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
|
||||||
num_frames = np.repeat(num_frames, repeat_time)
|
|
||||||
|
|
||||||
if num_frames is None or isinstance(num_frames, int):
|
|
||||||
# Normalize and smoothen the weights.
|
|
||||||
std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
|
|
||||||
mean = torch.mean(weights, dim=-2, keepdim=True)
|
|
||||||
weights = (weights - mean) / std
|
|
||||||
weights = _median_filter(weights, self.config.median_filter_width)
|
|
||||||
|
|
||||||
# Average the different cross-attention heads.
|
|
||||||
weights = weights.mean(dim=1)
|
|
||||||
|
|
||||||
# Perform dynamic time warping on each element of the batch.
|
|
||||||
for batch_idx in range(batch_size):
|
|
||||||
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)):
|
|
||||||
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
|
|
||||||
|
|
||||||
# Normalize and smoothen the weights.
|
|
||||||
std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
|
|
||||||
mean = torch.mean(matrix, dim=-2, keepdim=True)
|
|
||||||
matrix = (matrix - mean) / std
|
|
||||||
matrix = _median_filter(matrix, self.config.median_filter_width)
|
|
||||||
|
|
||||||
# Average the different cross-attention heads.
|
|
||||||
matrix = matrix.mean(dim=0)
|
|
||||||
else:
|
|
||||||
matrix = weights[batch_idx]
|
|
||||||
|
|
||||||
text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
|
|
||||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
|
||||||
jump_times = time_indices[jumps] * time_precision
|
|
||||||
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
|
|
||||||
|
|
||||||
return timestamps
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperDecoderWrapper(WhisperPreTrainedModel):
|
class WhisperDecoderWrapper(WhisperPreTrainedModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -530,10 +530,21 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
timestamp_begin = self.all_special_ids[-1] + 1
|
timestamp_begin = self.all_special_ids[-1] + 1
|
||||||
outputs = [[]]
|
outputs = [[]]
|
||||||
|
|
||||||
|
cur_max_timestamp = 0.0
|
||||||
|
prev_segments_len = 0.0
|
||||||
|
|
||||||
for token in token_ids:
|
for token in token_ids:
|
||||||
if token >= timestamp_begin:
|
if token >= timestamp_begin:
|
||||||
timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
|
timestamp = float((token - timestamp_begin) * time_precision)
|
||||||
outputs.append(timestamp)
|
|
||||||
|
if timestamp < cur_max_timestamp:
|
||||||
|
# next segment has started
|
||||||
|
prev_segments_len += cur_max_timestamp
|
||||||
|
|
||||||
|
cur_max_timestamp = timestamp
|
||||||
|
|
||||||
|
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
|
||||||
outputs.append([])
|
outputs.append([])
|
||||||
else:
|
else:
|
||||||
outputs[-1].append(token)
|
outputs[-1].append(token)
|
||||||
@@ -631,7 +642,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
clean_up_tokenization_spaces: bool = None,
|
clean_up_tokenization_spaces: bool = None,
|
||||||
output_offsets: bool = False,
|
output_offsets: bool = False,
|
||||||
time_precision=0.02,
|
time_precision: float = 0.02,
|
||||||
decode_with_timestamps: bool = False,
|
decode_with_timestamps: bool = False,
|
||||||
normalize: bool = False,
|
normalize: bool = False,
|
||||||
basic_normalize: bool = False,
|
basic_normalize: bool = False,
|
||||||
|
|||||||
@@ -224,10 +224,21 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
"""
|
"""
|
||||||
timestamp_begin = self.all_special_ids[-1] + 1
|
timestamp_begin = self.all_special_ids[-1] + 1
|
||||||
outputs = [[]]
|
outputs = [[]]
|
||||||
|
|
||||||
|
cur_max_timestamp = 0.0
|
||||||
|
prev_segments_len = 0.0
|
||||||
|
|
||||||
for token in token_ids:
|
for token in token_ids:
|
||||||
if token >= timestamp_begin:
|
if token >= timestamp_begin:
|
||||||
timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>"
|
timestamp = float((token - timestamp_begin) * time_precision)
|
||||||
outputs.append(timestamp)
|
|
||||||
|
if timestamp < cur_max_timestamp:
|
||||||
|
# next segment has started
|
||||||
|
prev_segments_len += cur_max_timestamp
|
||||||
|
|
||||||
|
cur_max_timestamp = timestamp
|
||||||
|
|
||||||
|
outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
|
||||||
outputs.append([])
|
outputs.append([])
|
||||||
else:
|
else:
|
||||||
outputs[-1].append(token)
|
outputs[-1].append(token)
|
||||||
@@ -330,7 +341,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
skip_special_tokens: bool = False,
|
skip_special_tokens: bool = False,
|
||||||
clean_up_tokenization_spaces: bool = None,
|
clean_up_tokenization_spaces: bool = None,
|
||||||
output_offsets: bool = False,
|
output_offsets: bool = False,
|
||||||
time_precision=0.02,
|
time_precision: float = 0.02,
|
||||||
decode_with_timestamps: bool = False,
|
decode_with_timestamps: bool = False,
|
||||||
normalize: bool = False,
|
normalize: bool = False,
|
||||||
basic_normalize: bool = False,
|
basic_normalize: bool = False,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1152,7 +1152,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
@slow
|
@slow
|
||||||
def test_whisper_longform(self):
|
def test_whisper_longform(self):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!"""
|
EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile."""
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
|||||||
Reference in New Issue
Block a user