From cd927a47361c95c8adc5037df6453b65cca9149f Mon Sep 17 00:00:00 2001 From: Matthijs Hollemans Date: Wed, 21 Jun 2023 17:48:21 +0200 Subject: [PATCH] add word-level timestamps to Whisper (#23205) * let's go! * initial implementation of token-level timestamps * only return a single timestamp per token * remove token probabilities * fix return type * fix doc comment * strip special tokens * rename * revert to not stripping special tokens * only support models that have alignment_heads * add integration test * consistently name it token-level timestamps * small DTW tweak * initial support for ASR pipeline * fix pipeline doc comments * resolve token timestamps in pipeline with chunking * change warning when no final timestamp is found * return word-level timestamps * fixup * fix bug that skipped final word in each chunk * fix failing unit tests * merge punctuations into the words * also return word tokens * also return token indices * add (failing) unit test for combine_tokens_into_words * make combine_tokens_into_words private * restore OpenAI's punctuation rules * add pipeline tests * make requested changes * PR review changes * fix failing pipeline test * small stuff from PR * only return words and their timestamps, not segments * move alignment_heads into generation config * forgot to set alignment_heads in pipeline tests * tiny comment fix * grr --- .../models/whisper/configuration_whisper.py | 8 +- .../models/whisper/modeling_whisper.py | 138 +++++++++++- .../models/whisper/tokenization_whisper.py | 206 +++++++++++++++++- .../whisper/tokenization_whisper_fast.py | 2 +- .../pipelines/automatic_speech_recognition.py | 28 ++- tests/models/whisper/test_modeling_whisper.py | 29 +++ .../whisper/test_tokenization_whisper.py | 20 +- ..._pipelines_automatic_speech_recognition.py | 50 +++++ 8 files changed, 456 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py 
index c95e5b2887..a8bbc9718f 100644 --- a/src/transformers/models/whisper/configuration_whisper.py +++ b/src/transformers/models/whisper/configuration_whisper.py @@ -171,7 +171,9 @@ class WhisperConfig(PretrainedConfig): The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time step, irrespectively of `mask_feature_prob`. Only relevant if `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`. - + median_filter_width (`int`, *optional*, defaults to 7): + Width of the median filter used to smooth the cross-attention outputs when computing token timestamps. + Should be an odd number. Example: @@ -229,6 +231,7 @@ class WhisperConfig(PretrainedConfig): mask_feature_prob=0.0, mask_feature_length=10, mask_feature_min_masks=0, + median_filter_width=7, **kwargs, ): self.vocab_size = vocab_size @@ -265,6 +268,9 @@ class WhisperConfig(PretrainedConfig): self.mask_feature_prob = mask_feature_prob self.mask_feature_length = mask_feature_length self.mask_feature_min_masks = mask_feature_min_masks + + self.median_filter_width = median_filter_width + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 703607ad4a..ef6a98b6c5 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -227,6 +227,81 @@ def _compute_mask_indices( return spec_aug_mask +def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor: + """ + Applies a median filter of width `filter_width` along the last dimension of the input. + + The `inputs` tensor is assumed to be 3- or 4-dimensional. 
+ """ + if filter_width <= 0 or filter_width % 2 != 1: + raise ValueError("`filter_width` should be an odd number") + + pad_width = filter_width // 2 + if inputs.shape[-1] <= pad_width: + return inputs + + # Pad the left and right edges. + inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect") + + # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) + result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width] + return result + + +def _dynamic_time_warping(matrix: np.ndarray): + """ + Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate + token-level timestamps. + """ + output_length, input_length = matrix.shape + cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf + trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32) + + cost[0, 0] = 0 + for j in range(1, input_length + 1): + for i in range(1, output_length + 1): + c0 = cost[i - 1, j - 1] + c1 = cost[i - 1, j] + c2 = cost[i, j - 1] + + if c0 < c1 and c0 < c2: + c, t = c0, 0 + elif c1 < c0 and c1 < c2: + c, t = c1, 1 + else: + c, t = c2, 2 + + cost[i, j] = matrix[i - 1, j - 1] + c + trace[i, j] = t + + # backtrace + i = trace.shape[0] - 1 + j = trace.shape[1] - 1 + trace[0, :] = 2 + trace[:, 0] = 1 + + text_indices = [] + time_indices = [] + while i > 0 or j > 0: + text_indices.append(i - 1) + time_indices.append(j - 1) + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise RuntimeError( + f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report." 
+ ) + + text_indices = np.array(text_indices)[::-1] + time_indices = np.array(time_indices)[::-1] + return text_indices, time_indices + + class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) @@ -1472,6 +1547,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): language=None, is_multilingual=None, prompt_ids: Optional[torch.Tensor] = None, + return_token_timestamps=None, **kwargs, ): """ @@ -1534,6 +1610,10 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. + return_token_timestamps (`bool`, *optional*): + Whether to return token-level timestamps with the text. This can be used with or without the + `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into + words. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. 
If the model is an encoder-decoder model, encoder @@ -1662,7 +1742,19 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): if generation_config.return_timestamps: logits_processor = [WhisperTimeStampLogitsProcessor(generation_config)] - return super().generate( + if return_token_timestamps: + kwargs["output_attentions"] = True + kwargs["return_dict_in_generate"] = True + + if getattr(generation_config, "task", None) == "translate": + logger.warning("Token-level timestamps may not be reliable for task 'translate'.") + if not hasattr(generation_config, "alignment_heads"): + raise ValueError( + "Model generation config has no `alignment_heads`, token-level timestamps not available. " + "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." + ) + + outputs = super().generate( inputs, generation_config, logits_processor, @@ -1672,6 +1764,11 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): **kwargs, ) + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + outputs["token_timestamps"] = self._extract_token_timestamps(outputs, generation_config.alignment_heads) + + return outputs + def prepare_inputs_for_generation( self, decoder_input_ids, @@ -1693,7 +1790,6 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): "decoder_attention_mask": None, } - # @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () @@ -1701,6 +1797,44 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past + def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02): + """ + Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to + map each output token to a position in the input audio. 
+ + Returns: + tensor containing the timestamps in seconds for each predicted token + """ + # Create a list with `decoder_layers` elements, each a tensor of shape + # (batch size, attention_heads, output length, input length). + cross_attentions = [] + for i in range(self.config.decoder_layers): + cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2)) + + # Select specific cross-attention layers and heads. This is a tensor + # of shape (batch size, num selected, output length, input length). + weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) + weights = weights.permute([1, 0, 2, 3]) + + # Normalize and smoothen the weights. + std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) + weights = (weights - mean) / std + weights = _median_filter(weights, self.config.median_filter_width) + + # Average the different cross-attention heads. + matrix = weights.mean(dim=1) + + timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32) + + # Perform dynamic time warping on each element of the batch. + for batch_idx in range(timestamps.shape[0]): + text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy()) + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] * time_precision + timestamps[batch_idx, 1:] = torch.tensor(jump_times) + + return timestamps + @add_start_docstrings( """ diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 52a2e2dd72..6053f479aa 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -585,7 +585,7 @@ class WhisperTokenizer(PreTrainedTokenizer): Whether or not to output the offsets of the tokens. This should only be set if the model predicted timestamps. 
decode_with_timestamps (`bool`, *optional*, defaults to `False`): - WHether or not to decode with timestamps included in the raw text. + Whether or not to decode with timestamps included in the raw text. Returns: `str`: The decoded sentence. """ @@ -779,6 +779,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_offset = 0.0 timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 previous_tokens = [] + previous_token_timestamps = [] skip = False right_stride_start = None @@ -788,6 +789,8 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, # We can drop everything to Python list, it's going to make # our lives easier token_ids = output["tokens"][0].tolist() + if return_timestamps == "word": + token_timestamps = output["token_timestamps"][0].tolist() # Those keep track of timestamps within strides # Which need to be skipped and resolve all tokens in a single @@ -820,6 +823,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, last_timestamp = token current_tokens = [] + current_token_timestamps = [] # - all tokens within output for i, token in enumerate(token_ids): @@ -883,20 +887,37 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, chunk["timestamp"][1] = time # Handling merges. 
previous_tokens.append(current_tokens) - resolved_tokens = _find_longest_common_sequence(previous_tokens) + if return_timestamps == "word": + previous_token_timestamps.append(current_token_timestamps) + resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence( + previous_tokens, previous_token_timestamps + ) resolved_text = tokenizer.decode(resolved_tokens) chunk["text"] = resolved_text + if return_timestamps == "word": + chunk["words"] = _collate_word_timestamps( + tokenizer, resolved_tokens, resolved_token_timestamps, last_language + ) chunks.append(chunk) # Flush all our temporary context previous_tokens = [] current_tokens = [] + previous_token_timestamps = [] + current_token_timestamps = [] chunk = new_chunk() else: # 4/ Regular token # We just append to the list of all tokens so we can handle # merges later and decode into text. current_tokens.append(token) + if return_timestamps == "word": + start_time = round(token_timestamps[i] + time_offset, 2) + if i + 1 < len(token_timestamps): + end_time = round(token_timestamps[i + 1] + time_offset, 2) + else: + end_time = None # should never happen + current_token_timestamps.append((start_time, end_time)) if "stride" in output: time_offset += chunk_len - stride_right @@ -904,21 +925,31 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, # Leftover tokens if current_tokens: previous_tokens.append(current_tokens) + if return_timestamps == "word": + previous_token_timestamps.append(current_token_timestamps) elif not (any(p for p in previous_tokens)): chunk = new_chunk() previous_tokens = [] current_tokens = [] + previous_token_timestamps = [] + current_token_timestamps = [] if previous_tokens: if return_timestamps: logger.warning( - "There was an error while processing timestamps, we haven't found a timestamp as last token. Was" - " WhisperTimeStampLogitsProcessor used?" 
+ "Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. " + "Also make sure WhisperTimeStampLogitsProcessor was used during generation." ) # Happens when we don't use timestamps - resolved_tokens = _find_longest_common_sequence(previous_tokens) + resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence( + previous_tokens, previous_token_timestamps + ) resolved_text = tokenizer.decode(resolved_tokens) chunk["text"] = resolved_text + if return_timestamps == "word": + chunk["words"] = _collate_word_timestamps( + tokenizer, resolved_tokens, resolved_token_timestamps, last_language + ) chunks.append(chunk) # Preparing and cleaning up the pipeline output @@ -931,20 +962,35 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, chunk["timestamp"] = tuple(chunk["timestamp"]) if not return_language: chunk.pop("language") - optional = {"chunks": chunks} + + if return_timestamps == "word": + new_chunks = [] + for chunk in chunks: + new_chunks.extend(chunk["words"]) + optional = {"chunks": new_chunks} + else: + optional = {"chunks": chunks} else: optional = {} return full_text, optional -def _find_longest_common_sequence(sequences): +def _find_longest_common_sequence(sequences, token_timestamp_sequences=None): # It would be much harder to do O(n) because of fault tolerance. # We actually have a really good property which is that the total sequence # MUST be those subsequences in order. + # If token_timestamp_sequences is provided, will split those sequences in + # exactly the same way. 
+ left_sequence = sequences[0] left_length = len(left_sequence) total_sequence = [] - for right_sequence in sequences[1:]: + + if token_timestamp_sequences: + left_token_timestamp_sequence = token_timestamp_sequences[0] + total_token_timestamp_sequence = [] + + for seq_idx, right_sequence in enumerate(sequences[1:]): # index = 0 max_ = 0.0 max_indices = (left_length, left_length, 0, 0) @@ -1018,6 +1064,148 @@ def _find_longest_common_sequence(sequences): left_sequence = right_sequence[right_mid:] left_length = len(left_sequence) + if token_timestamp_sequences: + total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid]) + left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:] + total_sequence.extend(left_sequence) - return total_sequence + if token_timestamp_sequences is None: + return total_sequence + + if len(token_timestamp_sequences) > 0: + total_token_timestamp_sequence.extend(left_token_timestamp_sequence) + return total_sequence, total_token_timestamp_sequence + else: + return total_sequence, [] + + +def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language): + words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language) + timings = [ + { + "text": word, + "timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]), + } + for word, indices in zip(words, token_indices) + ] + return timings + + +def _combine_tokens_into_words( + tokenizer, + tokens: List[int], + language: str = None, + prepend_punctuations: str = "\"'“¡¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", +): + """ + Groups tokens by word. Returns a tuple containing a list of strings with the words, a list of `token_id` + sequences with the tokens making up each word, and a list with the indices of those tokens in `tokens`. 
+ """ + if language is None: + language = tokenizer.language + if language is None: + language = "english" + + if language in {"chinese", "japanese", "thai", "lao", "myanmar"}: + # These languages don't typically use spaces. + words, word_tokens, token_indices = _split_tokens_on_unicode(tokenizer, tokens) + else: + words, word_tokens, token_indices = _split_tokens_on_spaces(tokenizer, tokens) + + _merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations) + return words, word_tokens, token_indices + + +def _split_tokens_on_unicode(tokenizer, tokens: List[int]): + """Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points.""" + decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True) + replacement_char = "\ufffd" + + words = [] + word_tokens = [] + token_indices = [] + current_tokens = [] + current_indices = [] + unicode_offset = 0 + + for token_idx, token in enumerate(tokens): + current_tokens.append(token) + current_indices.append(token_idx) + decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True) + + if ( + replacement_char not in decoded + or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char + ): + words.append(decoded) + word_tokens.append(current_tokens) + token_indices.append(current_indices) + current_tokens = [] + current_indices = [] + unicode_offset += len(decoded) + + return words, word_tokens, token_indices + + +def _split_tokens_on_spaces(tokenizer, tokens: List[int]): + """Combine tokens into words by splitting at whitespace and punctuation tokens.""" + subwords, subword_tokens_list, subword_indices_list = _split_tokens_on_unicode(tokenizer, tokens) + words = [] + word_tokens = [] + token_indices = [] + + for subword, subword_tokens, subword_indices in zip(subwords, subword_tokens_list, subword_indices_list): + special = subword_tokens[0] >= tokenizer.eos_token_id + with_space = subword.startswith(" ") 
+ punctuation = subword.strip() in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + + if special or with_space or punctuation or len(words) == 0: + words.append(subword) + word_tokens.append(subword_tokens) + token_indices.append(subword_indices) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + token_indices[-1].extend(subword_indices) + + return words, word_tokens, token_indices + + +def _merge_punctuations(words, tokens, indices, prepended, appended): + """Merges punctuation tokens with neighboring words.""" + # prepend punctuations + i = len(words) - 2 + j = len(words) - 1 + while i >= 0: + if words[i].startswith(" ") and words[i].strip() in prepended: + words[j] = words[i] + words[j] + tokens[j] = tokens[i] + tokens[j] + indices[j] = indices[i] + indices[j] + words[i] = "" + tokens[i] = [] + indices[i] = [] + else: + j = i + i -= 1 + + # append punctuations + i = 0 + j = 1 + while j < len(words): + if not words[i].endswith(" ") and words[j] in appended: + words[i] += words[j] + tokens[i] += tokens[j] + indices[i] += indices[j] + words[j] = "" + tokens[j] = [] + indices[j] = [] + else: + i = j + j += 1 + + # remove elements that are now empty + words[:] = [word for word in words if word] + tokens[:] = [token for token in tokens if token] + indices[:] = [idx for idx in indices if idx] diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index a8cd8d6627..4861de6528 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -295,7 +295,7 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): Whether or not to output the offsets of the tokens. This should only be set if the model predicted timestamps. decode_with_timestamps (`bool`, *optional*, defaults to `False`): - WHether or not to decode with timestamps included in the raw text. 
+ Whether or not to decode with timestamps included in the raw text. Returns: `str`: The decoded sentence. """ diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index b96363232f..7b23fb5278 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -246,12 +246,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): treat the first `left` samples and last `right` samples to be ignored in decoding (but used at inference to provide more context to the model). Only use `stride` with CTC models. return_timestamps (*optional*, `str`): - Only available for pure CTC models. If set to `"char"`, the pipeline will return `timestamps` along the - text for every character in the text. For instance if you get `[{"text": "h", "timestamps": (0.5,0.6), - {"text": "i", "timestamps": (0.7, .9)}]`, then it means the model predicts that the letter "h" was + Only available for pure CTC models. If set to `"char"`, the pipeline will return timestamps along the + text for every character in the text. For instance if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, + {"text": "i", "timestamp": (0.7, 0.9)}]`, then it means the model predicts that the letter "h" was pronounced after `0.5` and before `0.6` seconds. If set to `"word"`, the pipeline will return - `timestamps` along the text for every word in the text. For instance if you get `[{"text": "hi ", - "timestamps": (0.5,0.9), {"text": "there", "timestamps": (1.0, .1.5)}]`, then it means the model + timestamps along the text for every word in the text. For instance if you get `[{"text": "hi ", + "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp": (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was pronounced after `0.5` and before `0.9` seconds. 
generate_kwargs (`dict`, *optional*): The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a @@ -265,8 +265,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): - **text** (`str` ) -- The recognized text. - **chunks** (*optional(, `List[Dict]`) When using `return_timestamps`, the `chunks` will become a list containing all the various text - chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text": - "there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing + chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": + "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing `"".join(chunk["text"] for chunk in output["chunks"])`. """ return super().__call__(inputs, **kwargs) @@ -421,6 +421,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): generate_kwargs = {} if return_timestamps and self.type == "seq2seq_whisper": generate_kwargs["return_timestamps"] = return_timestamps + if return_timestamps == "word": + generate_kwargs["return_token_timestamps"] = True is_last = model_inputs.pop("is_last") if self.type in {"seq2seq", "seq2seq_whisper"}: @@ -447,7 +449,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): attention_mask=attention_mask, **generate_kwargs, ) - out = {"tokens": tokens} + if return_timestamps == "word" and self.type == "seq2seq_whisper": + out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]} + else: + out = {"tokens": tokens} if self.type == "seq2seq_whisper": stride = model_inputs.pop("stride", None) if stride is not None: @@ -486,9 +491,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if return_timestamps and self.type == "seq2seq": raise ValueError("We cannot return_timestamps yet on non-ctc models apart from Whisper !") if return_timestamps == "char" and self.type == "ctc_with_lm": - 
raise ValueError("CTC with LM cannot return `char` timestamps, only `words`") - if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper": - raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.") + raise ValueError("CTC with LM cannot return `char` timestamps, only `word`") + if return_timestamps == "char" and self.type == "seq2seq_whisper": + raise ValueError("Whisper cannot return `char` timestamps, use `True` or `word` instead.") if return_language is not None and self.type != "seq2seq_whisper": raise ValueError("Only whisper can return language for now.") @@ -574,6 +579,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): output.pop("logits", None) output.pop("is_last", None) output.pop("stride", None) + output.pop("token_timestamps", None) for k, v in output.items(): extra[k].append(v) return {"text": text, **optional, **extra} diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 2bb34aa1af..07fd5dc2ba 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + @slow + def test_tiny_token_timestamp_generation(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model.to(torch_device) + + input_speech = self._load_datasamples(4) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to( + torch_device + ) + + generate_outputs = model.generate( + input_features, max_length=448, return_timestamps=True, return_token_timestamps=True + ) + + self.assertEqual(generate_outputs.sequences.shape, 
generate_outputs.token_timestamps.shape) + + # fmt: off + EXPECTED_OUTPUT = torch.tensor([ + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600] + ]) + # fmt: on + + self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT)) + @slow def test_tiny_specaugment_librispeech(self): torch_device = "cpu" diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index aea86525f5..403c5fe9e8 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py 
@@ -15,7 +15,7 @@ import unittest from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast -from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence +from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence from transformers.testing_utils import slow from ...test_tokenization_common import TokenizerTesterMixin @@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist()) + def test_combine_tokens_into_words(self): + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + # 'whatever "whatever" said someone, clever!?' + encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323] + expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"] + expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]] + expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]] + output = _combine_tokens_into_words(tokenizer, encoded_input) + self.assertEqual(expected_words, output[0]) + self.assertEqual(expected_tokens, output[1]) + self.assertEqual(expected_indices, output[2]) + output_rust = _combine_tokens_into_words(rust_tokenizer, encoded_input) + self.assertEqual(expected_words, output_rust[0]) + self.assertEqual(expected_tokens, output_rust[1]) + self.assertEqual(expected_indices, output_rust[2]) + class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): checkpoint_name = "openai/whisper-small.en" diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 952508dca4..2b77c556f4 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -316,6 +316,27 @@ class 
AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): "chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}], }, ) + pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]] + res = pipe(sample["audio"]["array"], return_timestamps="word") + # fmt: off + # Note that the word-level timestamps predicted here are pretty bad. + self.assertEqual( + res, + { + "text": " Conquered returned to its place amidst the tents.", + "chunks": [ + {'text': ' Conquered', 'timestamp': (29.78, 29.9)}, + {'text': ' returned', 'timestamp': (29.9, 29.9)}, + {'text': ' to', 'timestamp': (29.9, 29.9)}, + {'text': ' its', 'timestamp': (29.9, 29.9)}, + {'text': ' place', 'timestamp': (29.9, 29.9)}, + {'text': ' amidst', 'timestamp': (29.9, 29.9)}, + {'text': ' the', 'timestamp': (29.9, 29.9)}, + {'text': ' tents.', 'timestamp': (29.9, 29.9)} + ] + } + ) + # fmt: on @require_torch @slow @@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): ], }, ) + speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]] + output = speech_recognizer(filename, return_timestamps="word") + # fmt: off + self.assertEqual( + output, + { + "text": " Mr. 
Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.", + "chunks": [ + {'text': ' Mr.', 'timestamp': (0.0, 1.02)}, + {'text': ' Quilter', 'timestamp': (1.02, 1.18)}, + {'text': ' is', 'timestamp': (1.18, 1.44)}, + {'text': ' the', 'timestamp': (1.44, 1.58)}, + {'text': ' apostle', 'timestamp': (1.58, 1.98)}, + {'text': ' of', 'timestamp': (1.98, 2.3)}, + {'text': ' the', 'timestamp': (2.3, 2.46)}, + {'text': ' middle', 'timestamp': (2.46, 2.56)}, + {'text': ' classes,', 'timestamp': (2.56, 3.38)}, + {'text': ' and', 'timestamp': (3.38, 3.52)}, + {'text': ' we', 'timestamp': (3.52, 3.6)}, + {'text': ' are', 'timestamp': (3.6, 3.72)}, + {'text': ' glad', 'timestamp': (3.72, 4.0)}, + {'text': ' to', 'timestamp': (4.0, 4.26)}, + {'text': ' welcome', 'timestamp': (4.26, 4.54)}, + {'text': ' his', 'timestamp': (4.54, 4.92)}, + {'text': ' gospel.', 'timestamp': (4.92, 6.66)}, + ], + }, + ) + # fmt: on @slow @require_torch