add word-level timestamps to Whisper (#23205)
* let's go! * initial implementation of token-level timestamps * only return a single timestamp per token * remove token probabilities * fix return type * fix doc comment * strip special tokens * rename * revert to not stripping special tokens * only support models that have alignment_heads * add integration test * consistently name it token-level timestamps * small DTW tweak * initial support for ASR pipeline * fix pipeline doc comments * resolve token timestamps in pipeline with chunking * change warning when no final timestamp is found * return word-level timestamps * fixup * fix bug that skipped final word in each chunk * fix failing unit tests * merge punctuations into the words * also return word tokens * also return token indices * add (failing) unit test for combine_tokens_into_words * make combine_tokens_into_words private * restore OpenAI's punctuation rules * add pipeline tests * make requested changes * PR review changes * fix failing pipeline test * small stuff from PR * only return words and their timestamps, not segments * move alignment_heads into generation config * forgot to set alignment_heads in pipeline tests * tiny comment fix * grr
This commit is contained in:
committed by
GitHub
parent
0f968ddaa3
commit
cd927a4736
@@ -171,7 +171,9 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
|
||||||
step, irrespectively of `mask_feature_prob`. Only relevant if
|
step, irrespectively of `mask_feature_prob`. Only relevant if
|
||||||
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
|
`mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
|
||||||
|
median_filter_width (`int`, *optional*, defaults to 7):
|
||||||
|
Width of the median filter used to smoothen to cross-attention outputs when computing token timestamps.
|
||||||
|
Should be an odd number.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@@ -229,6 +231,7 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
mask_feature_prob=0.0,
|
mask_feature_prob=0.0,
|
||||||
mask_feature_length=10,
|
mask_feature_length=10,
|
||||||
mask_feature_min_masks=0,
|
mask_feature_min_masks=0,
|
||||||
|
median_filter_width=7,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -265,6 +268,9 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
self.mask_feature_prob = mask_feature_prob
|
self.mask_feature_prob = mask_feature_prob
|
||||||
self.mask_feature_length = mask_feature_length
|
self.mask_feature_length = mask_feature_length
|
||||||
self.mask_feature_min_masks = mask_feature_min_masks
|
self.mask_feature_min_masks = mask_feature_min_masks
|
||||||
|
|
||||||
|
self.median_filter_width = median_filter_width
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
bos_token_id=bos_token_id,
|
bos_token_id=bos_token_id,
|
||||||
|
|||||||
@@ -227,6 +227,81 @@ def _compute_mask_indices(
|
|||||||
return spec_aug_mask
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
|
def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Applies a median filter of width `filter_width` along the last dimension of the input.
|
||||||
|
|
||||||
|
The `inputs` tensor is assumed to be 3- or 4-dimensional.
|
||||||
|
"""
|
||||||
|
if filter_width <= 0 or filter_width % 2 != 1:
|
||||||
|
raise ValueError("`filter_width` should be an odd number")
|
||||||
|
|
||||||
|
pad_width = filter_width // 2
|
||||||
|
if inputs.shape[-1] <= pad_width:
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
# Pad the left and right edges.
|
||||||
|
inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect")
|
||||||
|
|
||||||
|
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||||
|
result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _dynamic_time_warping(matrix: np.ndarray):
|
||||||
|
"""
|
||||||
|
Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate
|
||||||
|
token-level timestamps.
|
||||||
|
"""
|
||||||
|
output_length, input_length = matrix.shape
|
||||||
|
cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf
|
||||||
|
trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32)
|
||||||
|
|
||||||
|
cost[0, 0] = 0
|
||||||
|
for j in range(1, input_length + 1):
|
||||||
|
for i in range(1, output_length + 1):
|
||||||
|
c0 = cost[i - 1, j - 1]
|
||||||
|
c1 = cost[i - 1, j]
|
||||||
|
c2 = cost[i, j - 1]
|
||||||
|
|
||||||
|
if c0 < c1 and c0 < c2:
|
||||||
|
c, t = c0, 0
|
||||||
|
elif c1 < c0 and c1 < c2:
|
||||||
|
c, t = c1, 1
|
||||||
|
else:
|
||||||
|
c, t = c2, 2
|
||||||
|
|
||||||
|
cost[i, j] = matrix[i - 1, j - 1] + c
|
||||||
|
trace[i, j] = t
|
||||||
|
|
||||||
|
# backtrace
|
||||||
|
i = trace.shape[0] - 1
|
||||||
|
j = trace.shape[1] - 1
|
||||||
|
trace[0, :] = 2
|
||||||
|
trace[:, 0] = 1
|
||||||
|
|
||||||
|
text_indices = []
|
||||||
|
time_indices = []
|
||||||
|
while i > 0 or j > 0:
|
||||||
|
text_indices.append(i - 1)
|
||||||
|
time_indices.append(j - 1)
|
||||||
|
if trace[i, j] == 0:
|
||||||
|
i -= 1
|
||||||
|
j -= 1
|
||||||
|
elif trace[i, j] == 1:
|
||||||
|
i -= 1
|
||||||
|
elif trace[i, j] == 2:
|
||||||
|
j -= 1
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report."
|
||||||
|
)
|
||||||
|
|
||||||
|
text_indices = np.array(text_indices)[::-1]
|
||||||
|
time_indices = np.array(time_indices)[::-1]
|
||||||
|
return text_indices, time_indices
|
||||||
|
|
||||||
|
|
||||||
class WhisperPositionalEmbedding(nn.Embedding):
|
class WhisperPositionalEmbedding(nn.Embedding):
|
||||||
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
|
||||||
super().__init__(num_positions, embedding_dim)
|
super().__init__(num_positions, embedding_dim)
|
||||||
@@ -1472,6 +1547,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
language=None,
|
language=None,
|
||||||
is_multilingual=None,
|
is_multilingual=None,
|
||||||
prompt_ids: Optional[torch.Tensor] = None,
|
prompt_ids: Optional[torch.Tensor] = None,
|
||||||
|
return_token_timestamps=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -1534,6 +1610,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
|
provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for
|
||||||
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
|
transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words
|
||||||
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
|
correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value.
|
||||||
|
return_token_timestamps (`bool`, *optional*):
|
||||||
|
Whether to return token-level timestamps with the text. This can be used with or without the
|
||||||
|
`return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into
|
||||||
|
words.
|
||||||
kwargs:
|
kwargs:
|
||||||
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
|
||||||
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
||||||
@@ -1662,7 +1742,19 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
if generation_config.return_timestamps:
|
if generation_config.return_timestamps:
|
||||||
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)]
|
||||||
|
|
||||||
return super().generate(
|
if return_token_timestamps:
|
||||||
|
kwargs["output_attentions"] = True
|
||||||
|
kwargs["return_dict_in_generate"] = True
|
||||||
|
|
||||||
|
if getattr(generation_config, "task", None) == "translate":
|
||||||
|
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
||||||
|
if not hasattr(generation_config, "alignment_heads"):
|
||||||
|
raise ValueError(
|
||||||
|
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
|
||||||
|
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = super().generate(
|
||||||
inputs,
|
inputs,
|
||||||
generation_config,
|
generation_config,
|
||||||
logits_processor,
|
logits_processor,
|
||||||
@@ -1672,6 +1764,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
||||||
|
outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(
|
def prepare_inputs_for_generation(
|
||||||
self,
|
self,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
@@ -1693,7 +1790,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
"decoder_attention_mask": None,
|
"decoder_attention_mask": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
#
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past_key_values, beam_idx):
|
def _reorder_cache(past_key_values, beam_idx):
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
@@ -1701,6 +1797,44 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02):
|
||||||
|
"""
|
||||||
|
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
||||||
|
map each output token to a position in the input audio.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor containing the timestamps in seconds for each predicted token
|
||||||
|
"""
|
||||||
|
# Create a list with `decoder_layers` elements, each a tensor of shape
|
||||||
|
# (batch size, attention_heads, output length, input length).
|
||||||
|
cross_attentions = []
|
||||||
|
for i in range(self.config.decoder_layers):
|
||||||
|
cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2))
|
||||||
|
|
||||||
|
# Select specific cross-attention layers and heads. This is a tensor
|
||||||
|
# of shape (batch size, num selected, output length, input length).
|
||||||
|
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
||||||
|
weights = weights.permute([1, 0, 2, 3])
|
||||||
|
|
||||||
|
# Normalize and smoothen the weights.
|
||||||
|
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||||
|
weights = (weights - mean) / std
|
||||||
|
weights = _median_filter(weights, self.config.median_filter_width)
|
||||||
|
|
||||||
|
# Average the different cross-attention heads.
|
||||||
|
matrix = weights.mean(dim=1)
|
||||||
|
|
||||||
|
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Perform dynamic time warping on each element of the batch.
|
||||||
|
for batch_idx in range(timestamps.shape[0]):
|
||||||
|
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy())
|
||||||
|
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||||
|
jump_times = time_indices[jumps] * time_precision
|
||||||
|
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
|
||||||
|
|
||||||
|
return timestamps
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -585,7 +585,7 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||||
timestamps.
|
timestamps.
|
||||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
WHether or not to decode with timestamps included in the raw text.
|
Whether or not to decode with timestamps included in the raw text.
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The decoded sentence.
|
`str`: The decoded sentence.
|
||||||
"""
|
"""
|
||||||
@@ -779,6 +779,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
time_offset = 0.0
|
time_offset = 0.0
|
||||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
||||||
previous_tokens = []
|
previous_tokens = []
|
||||||
|
previous_token_timestamps = []
|
||||||
skip = False
|
skip = False
|
||||||
right_stride_start = None
|
right_stride_start = None
|
||||||
|
|
||||||
@@ -788,6 +789,8 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
# We can drop everything to Python list, it's going to make
|
# We can drop everything to Python list, it's going to make
|
||||||
# our lives easier
|
# our lives easier
|
||||||
token_ids = output["tokens"][0].tolist()
|
token_ids = output["tokens"][0].tolist()
|
||||||
|
if return_timestamps == "word":
|
||||||
|
token_timestamps = output["token_timestamps"][0].tolist()
|
||||||
|
|
||||||
# Those keep track of timestamps within strides
|
# Those keep track of timestamps within strides
|
||||||
# Which need to be skipped and resolve all tokens in a single
|
# Which need to be skipped and resolve all tokens in a single
|
||||||
@@ -820,6 +823,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
last_timestamp = token
|
last_timestamp = token
|
||||||
|
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
|
current_token_timestamps = []
|
||||||
|
|
||||||
# - all tokens within output
|
# - all tokens within output
|
||||||
for i, token in enumerate(token_ids):
|
for i, token in enumerate(token_ids):
|
||||||
@@ -883,20 +887,37 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
chunk["timestamp"][1] = time
|
chunk["timestamp"][1] = time
|
||||||
# Handling merges.
|
# Handling merges.
|
||||||
previous_tokens.append(current_tokens)
|
previous_tokens.append(current_tokens)
|
||||||
resolved_tokens = _find_longest_common_sequence(previous_tokens)
|
if return_timestamps == "word":
|
||||||
|
previous_token_timestamps.append(current_token_timestamps)
|
||||||
|
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
|
||||||
|
previous_tokens, previous_token_timestamps
|
||||||
|
)
|
||||||
resolved_text = tokenizer.decode(resolved_tokens)
|
resolved_text = tokenizer.decode(resolved_tokens)
|
||||||
chunk["text"] = resolved_text
|
chunk["text"] = resolved_text
|
||||||
|
if return_timestamps == "word":
|
||||||
|
chunk["words"] = _collate_word_timestamps(
|
||||||
|
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
|
||||||
|
)
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
# Flush all our temporary context
|
# Flush all our temporary context
|
||||||
previous_tokens = []
|
previous_tokens = []
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
|
previous_token_timestamps = []
|
||||||
|
current_token_timestamps = []
|
||||||
chunk = new_chunk()
|
chunk = new_chunk()
|
||||||
else:
|
else:
|
||||||
# 4/ Regular token
|
# 4/ Regular token
|
||||||
# We just append to the list of all tokens so we can handle
|
# We just append to the list of all tokens so we can handle
|
||||||
# merges later and decode into text.
|
# merges later and decode into text.
|
||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
|
if return_timestamps == "word":
|
||||||
|
start_time = round(token_timestamps[i] + time_offset, 2)
|
||||||
|
if i + 1 < len(token_timestamps):
|
||||||
|
end_time = round(token_timestamps[i + 1] + time_offset, 2)
|
||||||
|
else:
|
||||||
|
end_time = None # should never happen
|
||||||
|
current_token_timestamps.append((start_time, end_time))
|
||||||
|
|
||||||
if "stride" in output:
|
if "stride" in output:
|
||||||
time_offset += chunk_len - stride_right
|
time_offset += chunk_len - stride_right
|
||||||
@@ -904,21 +925,31 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
# Leftover tokens
|
# Leftover tokens
|
||||||
if current_tokens:
|
if current_tokens:
|
||||||
previous_tokens.append(current_tokens)
|
previous_tokens.append(current_tokens)
|
||||||
|
if return_timestamps == "word":
|
||||||
|
previous_token_timestamps.append(current_token_timestamps)
|
||||||
elif not (any(p for p in previous_tokens)):
|
elif not (any(p for p in previous_tokens)):
|
||||||
chunk = new_chunk()
|
chunk = new_chunk()
|
||||||
previous_tokens = []
|
previous_tokens = []
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
|
previous_token_timestamps = []
|
||||||
|
current_token_timestamps = []
|
||||||
|
|
||||||
if previous_tokens:
|
if previous_tokens:
|
||||||
if return_timestamps:
|
if return_timestamps:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
|
"Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. "
|
||||||
" WhisperTimeStampLogitsProcessor used?"
|
"Also make sure WhisperTimeStampLogitsProcessor was used during generation."
|
||||||
)
|
)
|
||||||
# Happens when we don't use timestamps
|
# Happens when we don't use timestamps
|
||||||
resolved_tokens = _find_longest_common_sequence(previous_tokens)
|
resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
|
||||||
|
previous_tokens, previous_token_timestamps
|
||||||
|
)
|
||||||
resolved_text = tokenizer.decode(resolved_tokens)
|
resolved_text = tokenizer.decode(resolved_tokens)
|
||||||
chunk["text"] = resolved_text
|
chunk["text"] = resolved_text
|
||||||
|
if return_timestamps == "word":
|
||||||
|
chunk["words"] = _collate_word_timestamps(
|
||||||
|
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
|
||||||
|
)
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
# Preparing and cleaning up the pipeline output
|
# Preparing and cleaning up the pipeline output
|
||||||
@@ -931,20 +962,35 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
chunk["timestamp"] = tuple(chunk["timestamp"])
|
chunk["timestamp"] = tuple(chunk["timestamp"])
|
||||||
if not return_language:
|
if not return_language:
|
||||||
chunk.pop("language")
|
chunk.pop("language")
|
||||||
optional = {"chunks": chunks}
|
|
||||||
|
if return_timestamps == "word":
|
||||||
|
new_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
new_chunks.extend(chunk["words"])
|
||||||
|
optional = {"chunks": new_chunks}
|
||||||
|
else:
|
||||||
|
optional = {"chunks": chunks}
|
||||||
else:
|
else:
|
||||||
optional = {}
|
optional = {}
|
||||||
return full_text, optional
|
return full_text, optional
|
||||||
|
|
||||||
|
|
||||||
def _find_longest_common_sequence(sequences):
|
def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
|
||||||
# It would be much harder to do O(n) because of fault tolerance.
|
# It would be much harder to do O(n) because of fault tolerance.
|
||||||
# We actually have a really good property which is that the total sequence
|
# We actually have a really good property which is that the total sequence
|
||||||
# MUST be those subsequences in order.
|
# MUST be those subsequences in order.
|
||||||
|
# If token_timestamp_sequences is provided, will split those sequences in
|
||||||
|
# exactly the same way.
|
||||||
|
|
||||||
left_sequence = sequences[0]
|
left_sequence = sequences[0]
|
||||||
left_length = len(left_sequence)
|
left_length = len(left_sequence)
|
||||||
total_sequence = []
|
total_sequence = []
|
||||||
for right_sequence in sequences[1:]:
|
|
||||||
|
if token_timestamp_sequences:
|
||||||
|
left_token_timestamp_sequence = token_timestamp_sequences[0]
|
||||||
|
total_token_timestamp_sequence = []
|
||||||
|
|
||||||
|
for seq_idx, right_sequence in enumerate(sequences[1:]):
|
||||||
# index = 0
|
# index = 0
|
||||||
max_ = 0.0
|
max_ = 0.0
|
||||||
max_indices = (left_length, left_length, 0, 0)
|
max_indices = (left_length, left_length, 0, 0)
|
||||||
@@ -1018,6 +1064,148 @@ def _find_longest_common_sequence(sequences):
|
|||||||
left_sequence = right_sequence[right_mid:]
|
left_sequence = right_sequence[right_mid:]
|
||||||
left_length = len(left_sequence)
|
left_length = len(left_sequence)
|
||||||
|
|
||||||
|
if token_timestamp_sequences:
|
||||||
|
total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid])
|
||||||
|
left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:]
|
||||||
|
|
||||||
total_sequence.extend(left_sequence)
|
total_sequence.extend(left_sequence)
|
||||||
|
|
||||||
return total_sequence
|
if token_timestamp_sequences is None:
|
||||||
|
return total_sequence
|
||||||
|
|
||||||
|
if len(token_timestamp_sequences) > 0:
|
||||||
|
total_token_timestamp_sequence.extend(left_token_timestamp_sequence)
|
||||||
|
return total_sequence, total_token_timestamp_sequence
|
||||||
|
else:
|
||||||
|
return total_sequence, []
|
||||||
|
|
||||||
|
|
||||||
|
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
|
||||||
|
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
|
||||||
|
timings = [
|
||||||
|
{
|
||||||
|
"text": word,
|
||||||
|
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
|
||||||
|
}
|
||||||
|
for word, indices in zip(words, token_indices)
|
||||||
|
]
|
||||||
|
return timings
|
||||||
|
|
||||||
|
|
||||||
|
def _combine_tokens_into_words(
|
||||||
|
tokenizer,
|
||||||
|
tokens: List[int],
|
||||||
|
language: str = None,
|
||||||
|
prepend_punctuations: str = "\"'“¡¿([{-",
|
||||||
|
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Groups tokens by word. Returns a tuple containing a list of strings with the words, and a list of `token_id`
|
||||||
|
sequences with the tokens making up each word.
|
||||||
|
"""
|
||||||
|
if language is None:
|
||||||
|
language = tokenizer.language
|
||||||
|
if language is None:
|
||||||
|
language = "english"
|
||||||
|
|
||||||
|
if language in {"chinese", "japanese", "thai", "lao", "myanmar"}:
|
||||||
|
# These languages don't typically use spaces.
|
||||||
|
words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens)
|
||||||
|
else:
|
||||||
|
words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens)
|
||||||
|
|
||||||
|
_merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations)
|
||||||
|
return words, word_tokens, token_indices
|
||||||
|
|
||||||
|
|
||||||
|
def _split_tokens_on_unicode(tokenizer, tokens: List[int]):
|
||||||
|
"""Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points."""
|
||||||
|
decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True)
|
||||||
|
replacement_char = "\ufffd"
|
||||||
|
|
||||||
|
words = []
|
||||||
|
word_tokens = []
|
||||||
|
token_indices = []
|
||||||
|
current_tokens = []
|
||||||
|
current_indices = []
|
||||||
|
unicode_offset = 0
|
||||||
|
|
||||||
|
for token_idx, token in enumerate(tokens):
|
||||||
|
current_tokens.append(token)
|
||||||
|
current_indices.append(token_idx)
|
||||||
|
decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True)
|
||||||
|
|
||||||
|
if (
|
||||||
|
replacement_char not in decoded
|
||||||
|
or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char
|
||||||
|
):
|
||||||
|
words.append(decoded)
|
||||||
|
word_tokens.append(current_tokens)
|
||||||
|
token_indices.append(current_indices)
|
||||||
|
current_tokens = []
|
||||||
|
current_indices = []
|
||||||
|
unicode_offset += len(decoded)
|
||||||
|
|
||||||
|
return words, word_tokens, token_indices
|
||||||
|
|
||||||
|
|
||||||
|
def _split_tokens_on_spaces(tokenizer, tokens: List[int]):
|
||||||
|
"""Combine tokens into words by splitting at whitespace and punctuation tokens."""
|
||||||
|
subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens)
|
||||||
|
words = []
|
||||||
|
word_tokens = []
|
||||||
|
token_indices = []
|
||||||
|
|
||||||
|
for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list):
|
||||||
|
special = subword_tokens[0] >= tokenizer.eos_token_id
|
||||||
|
with_space = subword.startswith(" ")
|
||||||
|
punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
|
||||||
|
|
||||||
|
if special or with_space or punctuation or len(words) == 0:
|
||||||
|
words.append(subword)
|
||||||
|
word_tokens.append(subword_tokens)
|
||||||
|
token_indices.append(subword_indices)
|
||||||
|
else:
|
||||||
|
words[-1] = words[-1] + subword
|
||||||
|
word_tokens[-1].extend(subword_tokens)
|
||||||
|
token_indices[-1].extend(subword_indices)
|
||||||
|
|
||||||
|
return words, word_tokens, token_indices
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_punctuations(words, tokens, indices, prepended, appended):
|
||||||
|
"""Merges punctuation tokens with neighboring words."""
|
||||||
|
# prepend punctuations
|
||||||
|
i = len(words) - 2
|
||||||
|
j = len(words) - 1
|
||||||
|
while i >= 0:
|
||||||
|
if words[i].startswith(" ") and words[i].strip() in prepended:
|
||||||
|
words[j] = words[i] + words[j]
|
||||||
|
tokens[j] = tokens[i] + tokens[j]
|
||||||
|
indices[j] = indices[i] + indices[j]
|
||||||
|
words[i] = ""
|
||||||
|
tokens[i] = []
|
||||||
|
indices[i] = []
|
||||||
|
else:
|
||||||
|
j = i
|
||||||
|
i -= 1
|
||||||
|
|
||||||
|
# append punctuations
|
||||||
|
i = 0
|
||||||
|
j = 1
|
||||||
|
while j < len(words):
|
||||||
|
if not words[i].endswith(" ") and words[j] in appended:
|
||||||
|
words[i] += words[j]
|
||||||
|
tokens[i] += tokens[j]
|
||||||
|
indices[i] += indices[j]
|
||||||
|
words[j] = ""
|
||||||
|
tokens[j] = []
|
||||||
|
indices[j] = []
|
||||||
|
else:
|
||||||
|
i = j
|
||||||
|
j += 1
|
||||||
|
|
||||||
|
# remove elements that are now empty
|
||||||
|
words[:] = [word for word in words if word]
|
||||||
|
tokens[:] = [token for token in tokens if token]
|
||||||
|
indices[:] = [idx for idx in indices if idx]
|
||||||
|
|||||||
@@ -295,7 +295,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
Whether or not to output the offsets of the tokens. This should only be set if the model predicted
|
||||||
timestamps.
|
timestamps.
|
||||||
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
decode_with_timestamps (`bool`, *optional*, defaults to `False`):
|
||||||
WHether or not to decode with timestamps included in the raw text.
|
Whether or not to decode with timestamps included in the raw text.
|
||||||
Returns:
|
Returns:
|
||||||
`str`: The decoded sentence.
|
`str`: The decoded sentence.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -246,12 +246,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
|
||||||
inference to provide more context to the model). Only use `stride` with CTC models.
|
inference to provide more context to the model). Only use `stride` with CTC models.
|
||||||
return_timestamps (*optional*, `str`):
|
return_timestamps (*optional*, `str`):
|
||||||
Only available for pure CTC models. If set to `"char"`, the pipeline will return `timestamps` along the
|
Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the
|
||||||
text for every character in the text. For instance if you get `[{"text": "h", "timestamps": (0.5,0.6),
|
text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)},
|
||||||
{"text": "i", "timestamps": (0.7, .9)}]`, then it means the model predicts that the letter "h" was
|
{"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was
|
||||||
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
|
pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return
|
||||||
`timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ",
|
||||||
"timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model
|
"timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model
|
||||||
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds.
|
||||||
generate_kwargs (`dict`, *optional*):
|
generate_kwargs (`dict`, *optional*):
|
||||||
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
|
||||||
@@ -265,8 +265,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
- **text** (`str` ) -- The recognized text.
|
- **text** (`str` ) -- The recognized text.
|
||||||
- **chunks** (*optional(, `List[Dict]`)
|
- **chunks** (*optional(, `List[Dict]`)
|
||||||
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
||||||
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
|
||||||
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
"there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
||||||
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
||||||
"""
|
"""
|
||||||
return super().__call__(inputs, **kwargs)
|
return super().__call__(inputs, **kwargs)
|
||||||
@@ -421,6 +421,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
generate_kwargs = {}
|
generate_kwargs = {}
|
||||||
if return_timestamps and self.type == "seq2seq_whisper":
|
if return_timestamps and self.type == "seq2seq_whisper":
|
||||||
generate_kwargs["return_timestamps"] = return_timestamps
|
generate_kwargs["return_timestamps"] = return_timestamps
|
||||||
|
if return_timestamps == "word":
|
||||||
|
generate_kwargs["return_token_timestamps"] = True
|
||||||
is_last = model_inputs.pop("is_last")
|
is_last = model_inputs.pop("is_last")
|
||||||
|
|
||||||
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
if self.type in {"seq2seq", "seq2seq_whisper"}:
|
||||||
@@ -447,7 +449,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
out = {"tokens": tokens}
|
if return_timestamps == "word" and self.type == "seq2seq_whisper":
|
||||||
|
out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
|
||||||
|
else:
|
||||||
|
out = {"tokens": tokens}
|
||||||
if self.type == "seq2seq_whisper":
|
if self.type == "seq2seq_whisper":
|
||||||
stride = model_inputs.pop("stride", None)
|
stride = model_inputs.pop("stride", None)
|
||||||
if stride is not None:
|
if stride is not None:
|
||||||
@@ -486,9 +491,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if return_timestamps and self.type == "seq2seq":
|
if return_timestamps and self.type == "seq2seq":
|
||||||
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
|
raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !")
|
||||||
if return_timestamps == "char" and self.type == "ctc_with_lm":
|
if return_timestamps == "char" and self.type == "ctc_with_lm":
|
||||||
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`")
|
raise ValueError("CTC with LM cannot return `char` timestamps, only `word`")
|
||||||
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
|
if return_timestamps == "char" and self.type == "seq2seq_whisper":
|
||||||
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
|
raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.")
|
||||||
|
|
||||||
if return_language is not None and self.type != "seq2seq_whisper":
|
if return_language is not None and self.type != "seq2seq_whisper":
|
||||||
raise ValueError("Only whisper can return language for now.")
|
raise ValueError("Only whisper can return language for now.")
|
||||||
@@ -574,6 +579,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
output.pop("logits", None)
|
output.pop("logits", None)
|
||||||
output.pop("is_last", None)
|
output.pop("is_last", None)
|
||||||
output.pop("stride", None)
|
output.pop("stride", None)
|
||||||
|
output.pop("token_timestamps", None)
|
||||||
for k, v in output.items():
|
for k, v in output.items():
|
||||||
extra[k].append(v)
|
extra[k].append(v)
|
||||||
return {"text": text, **optional, **extra}
|
return {"text": text, **optional, **extra}
|
||||||
|
|||||||
@@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_tiny_token_timestamp_generation(self):
|
||||||
|
set_seed(0)
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
input_speech = self._load_datasamples(4)
|
||||||
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_outputs = model.generate(
|
||||||
|
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
EXPECTED_OUTPUT = torch.tensor([
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
|
||||||
|
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
|
||||||
|
])
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_specaugment_librispeech(self):
|
def test_tiny_specaugment_librispeech(self):
|
||||||
torch_device = "cpu"
|
torch_device = "cpu"
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||||
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
|
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
from ...test_tokenization_common import TokenizerTesterMixin
|
from ...test_tokenization_common import TokenizerTesterMixin
|
||||||
@@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
||||||
|
|
||||||
|
def test_combine_tokens_into_words(self):
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
rust_tokenizer = self.get_rust_tokenizer()
|
||||||
|
|
||||||
|
# 'whatever "whatever" said someone, clever!?'
|
||||||
|
encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
|
||||||
|
expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
|
||||||
|
expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]]
|
||||||
|
expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
|
||||||
|
output = _combine_tokens_into_words(tokenizer, encoded_input)
|
||||||
|
self.assertEqual(expected_words, output[0])
|
||||||
|
self.assertEqual(expected_tokens, output[1])
|
||||||
|
self.assertEqual(expected_indices, output[2])
|
||||||
|
output_rust = _combine_tokens_into_words(rust_tokenizer, encoded_input)
|
||||||
|
self.assertEqual(expected_words, output_rust[0])
|
||||||
|
self.assertEqual(expected_tokens, output_rust[1])
|
||||||
|
self.assertEqual(expected_indices, output_rust[2])
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|||||||
@@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
|
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||||
|
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
||||||
|
# fmt: off
|
||||||
|
# Note that the word-level timestamps predicted here are pretty bad.
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
"text": " Conquered returned to its place amidst the tents.",
|
||||||
|
"chunks": [
|
||||||
|
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
|
||||||
|
{'text': ' returned', 'timestamp': (29.9, 29.9)},
|
||||||
|
{'text': ' to', 'timestamp': (29.9, 29.9)},
|
||||||
|
{'text': ' its', 'timestamp': (29.9, 29.9)},
|
||||||
|
{'text': ' place', 'timestamp': (29.9, 29.9)},
|
||||||
|
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
|
||||||
|
{'text': ' the', 'timestamp': (29.9, 29.9)},
|
||||||
|
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
@@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
],
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||||
|
output = speech_recognizer(filename, return_timestamps="word")
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
output,
|
||||||
|
{
|
||||||
|
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||||
|
"chunks": [
|
||||||
|
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
|
||||||
|
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
|
||||||
|
{'text': ' is', 'timestamp': (1.18, 1.44)},
|
||||||
|
{'text': ' the', 'timestamp': (1.44, 1.58)},
|
||||||
|
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
|
||||||
|
{'text': ' of', 'timestamp': (1.98, 2.3)},
|
||||||
|
{'text': ' the', 'timestamp': (2.3, 2.46)},
|
||||||
|
{'text': ' middle', 'timestamp': (2.46, 2.56)},
|
||||||
|
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
|
||||||
|
{'text': ' and', 'timestamp': (3.38, 3.52)},
|
||||||
|
{'text': ' we', 'timestamp': (3.52, 3.6)},
|
||||||
|
{'text': ' are', 'timestamp': (3.6, 3.72)},
|
||||||
|
{'text': ' glad', 'timestamp': (3.72, 4.0)},
|
||||||
|
{'text': ' to', 'timestamp': (4.0, 4.26)},
|
||||||
|
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
|
||||||
|
{'text': ' his', 'timestamp': (4.54, 4.92)},
|
||||||
|
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user