From 13254591054630b08d1a1338aa5ca9674d2513ed Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 2 Mar 2023 18:12:19 +0100 Subject: [PATCH] Refactor whisper asr pipeline to include language too. (#21427) * [WIP] whisper refacto to support language output. * Handling merges. * A bit more cleanup and comments. * Many improvements. Lots of details everywhere. * Cleanup old code and tests. * Handle lone timestamp tokens (just recover when something bad happens). * Adding return_language example. * No ffmpeg. * Hmm. * Some corrections. * Both fast and slow. * New black. * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove print. * Undoing tests modifications. * Smaller test modifications. * Rename. * Remove maxDiff. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/whisper/tokenization_whisper.py | 286 ++++++++++++++++++ .../whisper/tokenization_whisper_fast.py | 11 +- .../pipelines/automatic_speech_recognition.py | 262 ++++++++-------- .../whisper/test_tokenization_whisper.py | 79 +++++ ..._pipelines_automatic_speech_recognition.py | 8 +- 5 files changed, 518 insertions(+), 128 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 7e8f675a15..3d795e5b87 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -703,3 +703,289 @@ class WhisperTokenizer(PreTrainedTokenizer): forced_tokens = self.prefix_tokens[1:] forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)] return forced_decoder_ids + + def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision): + return 
_decode_asr( + self, + model_outputs, + return_timestamps=return_timestamps, + return_language=return_language, + time_precision=time_precision, + ) + + +def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision): + """ + Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle + the various options not allowed in other seq2seq models + """ + + # =========== Overview ============ + # - iterate over all outputs + # - all tokens within output + # - Each token can be + # - language token + # - special token + # - timestamp token + # - text token + # - We accumulate the text tokens. + # - We split on end timestamps + # - Lots of complexity comes from stride and timestamps + + last_language = None + + def new_chunk(): + return {"language": last_language, "timestamp": [None, None], "text": ""} + + # Welcome to the state machine ! + chunks = [] + chunk = new_chunk() + time_offset = 0.0 + timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 + previous_tokens = [] + skip = False + right_stride_start = None + + all_special_ids = set(tokenizer.all_special_ids) + # - iterate over all outputs + for chunk_id, output in enumerate(model_outputs): + # We can drop everything to Python list, it's going to make + # our lives easier + token_ids = output["tokens"][0].tolist() + + # Those keep track of timestamps within strides + # Which need to be skipped and resolve all tokens in a single + # chunk. + last_timestamp = None + first_timestamp = timestamp_begin + + if "stride" in output: + chunk_len, stride_left, stride_right = output["stride"] + # Offset the timings to account for the other `model_outputs`. + time_offset -= stride_left + right_stride_start = chunk_len - stride_right + + # Keeping track of timestamps within strides + # We're going to NOT split on those, and delay until we're + # out of BOTH stride. 
Otherwise lots of issues occur and + # corner cases + if stride_left: + first_timestamp = stride_left / time_precision + timestamp_begin + if stride_right: + for token in reversed(token_ids): + if token >= timestamp_begin: + # There can be several tokens in the right stride + # But the last one is ALWAYS going to be skipped + if ( + last_timestamp is not None + and (token - timestamp_begin) * time_precision < right_stride_start + ): + break + last_timestamp = token + + current_tokens = [] + + # - all tokens within output + for i, token in enumerate(token_ids): + # 4 possible states for each token + # - 1/ Language code + # - 2/ all other special tokens (which we ignore) + # - 3/ Timestamp + # - 4/ Regular text + if token in all_special_ids: + # Either language code or other + text = tokenizer.decode([token]) + # Removing outer shell <|XX|> + text = text[2:-2] + language = LANGUAGES.get(text, None) + if language is not None: + # 1/ Indeed some language + # TODO Handle when language is different from the previous + # one, and we cannot use timestamped tokens to create chunks + if last_language and language != last_language and not return_timestamps: + previous_tokens.append(current_tokens) + resolved_tokens = _find_longest_common_sequence(previous_tokens) + resolved_text = tokenizer.decode(resolved_tokens) + chunk["text"] = resolved_text + chunks.append(chunk) + + # Flush all our temporary context + previous_tokens = [] + current_tokens = [] + chunk = new_chunk() + chunk["language"] = language + last_language = language + else: + # 2/ This is a regular special token, ignoring it + pass + elif token >= timestamp_begin: + # 3/ Timestamp token + time = (token - timestamp_begin) * time_precision + time_offset + time = round(time, 2) + if last_timestamp and token >= last_timestamp: + # Whisper outputted a timestamp token, but it falls within + # our stride, so we're going to skip it for the time being + # and resolve this later + # Skip is necessary because timestamp 
tokens always come + # by pair, so we need to skip the next one too (which would mark the start of another chunk). + skip = True + elif skip or (previous_tokens and token < first_timestamp): + skip = False + elif chunk["timestamp"][0] is None: + chunk["timestamp"][0] = time + else: + # This is the end of the timestamp chunk + if time == chunk["timestamp"][0]: + # This is a bug in timestamp token output + # where we're taking the duplicate token + # as a stop where it should be a start. + # This is an issue in the underlying model output + # Let's just skip it so it becomes de facto + # a start again + pass + else: + chunk["timestamp"][1] = time + # Handling merges. + previous_tokens.append(current_tokens) + resolved_tokens = _find_longest_common_sequence(previous_tokens) + resolved_text = tokenizer.decode(resolved_tokens) + chunk["text"] = resolved_text + chunks.append(chunk) + + # Flush all our temporary context + previous_tokens = [] + current_tokens = [] + chunk = new_chunk() + else: + # 4/ Regular token + # We just append to the list of all tokens so we can handle + # merges later and decode into text. + current_tokens.append(token) + + if "stride" in output: + time_offset += chunk_len - stride_right + + # Leftover tokens + if current_tokens: + previous_tokens.append(current_tokens) + elif not (any(p for p in previous_tokens)): + # print("Flushing previous tokens (END)") + chunk = new_chunk() + previous_tokens = [] + current_tokens = [] + + if previous_tokens: + if return_timestamps: + # Last token should always be timestamps, so there shouldn't be + # leftover + raise ValueError( + "There was an error while processing timestamps, we haven't found a timestamp as last token. Was" + " WhisperTimeStampLogitsProcessor used?" 
+ ) + # Happens when we don't use timestamps + resolved_tokens = _find_longest_common_sequence(previous_tokens) + # print("Flushing previous tokens (FINAL)") + resolved_text = tokenizer.decode(resolved_tokens) + chunk["text"] = resolved_text + chunks.append(chunk) + + # Preparing and cleaning up the pipeline output + full_text = "".join(chunk["text"] for chunk in chunks) + if return_timestamps or return_language: + for chunk in chunks: + if not return_timestamps: + chunk.pop("timestamp") + else: + chunk["timestamp"] = tuple(chunk["timestamp"]) + if not return_language: + chunk.pop("language") + optional = {"chunks": chunks} + else: + optional = {} + return full_text, optional + + +def _find_longest_common_sequence(sequences): + # It would be much harder to do O(n) because of fault tolerance. + # We actually have a really good property which is that the total sequence + # MUST be those subsequences in order. + left_sequence = sequences[0] + left_length = len(left_sequence) + total_sequence = [] + for right_sequence in sequences[1:]: + # index = 0 + max_ = 0.0 + max_indices = (left_length, left_length, 0, 0) + # Here we're sliding matches + # [a, b, c, d] + # [c, d, f] + # = [c] == [d] + # + # [a, b, c, d] + # [c, d, f] + # = [c, d] == [c, d] + # + # + # [a, b, c, d] + # [c, d, f] + # + # = [b, c, d] == [c, d, f] + # + # [a, b, c, d] + # [c, d, f] + # + # [a, b, c] == [c, d, f] + # + # [a, b, c, d] + # [d, f] + # + # [a, b] == [d, f] + # + # [a, b, c, d] + # [f] + # + # [a] == [f] + right_length = len(right_sequence) + for i in range(1, left_length + right_length): + # epsilon to favor long perfect matches + eps = i / 10000.0 + + # Slightly convoluted because we don't want out of bound indices + # This will be necessary for a small conflict resolution optimization + # later + left_start = max(0, left_length - i) + left_stop = min(left_length, left_length + right_length - i) + left = np.array(left_sequence[left_start:left_stop]) + + right_start = max(0, i - 
left_length) + right_stop = min(right_length, i) + right = np.array(right_sequence[right_start:right_stop]) + + # We can only match subsequences of the same size. + if len(left) != len(right): + raise RuntimeError( + "There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference." + ) + + matches = np.sum(left == right) + matching = matches / i + eps + if matches > 1 and matching > max_: + max_ = matching + max_indices = (left_start, left_stop, right_start, right_stop) + + (left_start, left_stop, right_start, right_stop) = max_indices + + # This is a small conflict optimization since those sequences overlap + # in audio. + # We're going to give more confidence to the left sequence + # for the left of the overlap, + # and to the right of the sequence, for the right of the overlap + left_mid = (left_stop + left_start) // 2 + right_mid = (right_stop + right_start) // 2 + total_sequence.extend(left_sequence[:left_mid]) + left_sequence = right_sequence[right_mid:] + left_length = len(left_sequence) + + total_sequence.extend(left_sequence) + + return total_sequence diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index caf81da8e0..3110aac8b1 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -24,7 +24,7 @@ from ...tokenization_utils_base import BatchEncoding from ...tokenization_utils_fast import PreTrainedTokenizerFast from ...utils import logging from .english_normalizer import EnglishTextNormalizer -from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer +from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr logger = logging.get_logger(__name__) @@ -475,3 +475,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast): forced_tokens = self.prefix_tokens[1:] 
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)] return forced_decoder_ids + + def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision): + return _decode_asr( + self, + model_outputs, + return_timestamps=return_timestamps, + return_language=return_language, + time_precision=time_precision, + ) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 2780355d95..3d1a1c7348 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -82,112 +82,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, break -def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions): - """ - Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since - `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only - iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is - processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to - properly compute the final `offset`. 
- """ - # index of the first timestamp token - timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 - items = [] - # approximation of the token to time ratio : ~0.2seconds - time_precision = feature_extractor.chunk_length / max_source_positions - time = 0 - for seq_idx, item in enumerate(sequences): - sequence, stride = item - if isinstance(sequence, list): - sequence = np.array(sequence) - chunk_len, stride_left, stride_right = stride - sequence = sequence.squeeze(0) - # get rid of the `forced_decoder_idx` that are use to parametrize the generation - begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0 - sequence = sequence[begin_idx:] - - timestamp_tokens = sequence >= timestamp_begin - if seq_idx != 0 and sum(timestamp_tokens) > 0: - consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 - last_timestamp = np.where(timestamp_tokens)[0][-1] - consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive - time -= stride_left + stride_right - offset = int((time / feature_extractor.sampling_rate) / time_precision) - overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision) - # relevant timestamps are in the overlapping part - relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0] - if relevant_timestamp.shape[0] > 0: - relevant_timestamp = ( - consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0] - ) - # if a big stride is used, we need to check some of the previous items for the best overlap - best_match = 0 - sliced_sequence = [] - for idx, previous_sequence in enumerate(reversed(items)): - previous_tokens = previous_sequence[1:-1] - if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0: - break # the previous sequence is too far in the past - if len(previous_tokens) > 0: - # find the longest common sequence between 
the overlapping parts - index_left, index_right, match_length = _fast_find_longest_common_sequence( - sequence[1:relevant_timestamp], previous_tokens - ) - # don't do anything if only 1 token was matched - if match_length > 1 and match_length > best_match: - best_match = match_length - best_idx = idx - end_of_curr_sequence_idx = ( - np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1 - ) - end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left - # if all the tokens are matched, suffix - if index_left == 0 and match_length == len(previous_tokens): - sliced_sequence = np.insert( - sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0] - ) - sliced_sequence[-1] = previous_sequence[-1] - # if part of the previous sequence is not taken - elif index_left >= 0: - sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx] - # let's insert the missing part of the previous sequence - previous_slice = ( - previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]] - ) - sliced_sequence = np.insert(sliced_sequence, 0, previous_slice) - sliced_sequence[-1] += offset - - if len(sliced_sequence) > 0: - items[len(items) - best_idx - 1] = sliced_sequence - items = items[: len(items) - best_idx] - sequence = sequence[end_of_curr_sequence_idx:] - - # sequence might have changed - timestamp_tokens = sequence >= timestamp_begin - consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 - if sum(timestamp_tokens) > 0: - last_timestamp = np.where(timestamp_tokens)[0][-1] - consecutive = ( - np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive - ) - - if len(consecutive) > 0: - last_slice = 0 - for current_slice in consecutive: - actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0] - sliced_tokens = sequence[last_slice:current_slice] - duration = sliced_tokens[-1] - sliced_tokens[0] - sliced_tokens[0] = 
actual_offset - sliced_tokens[-1] = actual_offset + duration - items.append(sliced_tokens) - last_slice = current_slice - - time += chunk_len - result = [] - for i in range(len(items)): - result += items[i].tolist() - return result - - def _fast_find_longest_common_sequence(sequence_left, sequence_right): seq_len_left = len(sequence_left) seq_len_right = len(sequence_right) @@ -384,6 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ignore_warning=None, decoder_kwargs=None, return_timestamps=None, + return_language=None, generate_kwargs=None, max_new_tokens=None, ): @@ -413,6 +308,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if return_timestamps is not None: forward_params["return_timestamps"] = return_timestamps postprocess_params["return_timestamps"] = return_timestamps + if return_language is not None: + postprocess_params["return_language"] = return_language return preprocess_params, forward_params, postprocess_params @@ -580,7 +477,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): extra = model_inputs return {"is_last": is_last, **out, **extra} - def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None): + def postprocess( + self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None + ): # Optional return types optional = {} @@ -591,12 +490,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper": raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.") + if return_language is not None and self.type != "seq2seq_whisper": + raise ValueError("Only whisper can return language for now.") + final_items = [] key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: items = outputs[key].numpy() - stride = outputs.pop("stride", None) + stride = outputs.get("stride", None) if 
stride is not None and self.type in {"ctc", "ctc_with_lm"}: total_n, left, right = stride # Total_n might be < logits.shape[1] @@ -605,15 +507,28 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): # This won't work with left padding (which doesn't exist right now) right_n = total_n - right items = items[:, left:right_n] - if self.type == "seq2seq_whisper" and return_timestamps and stride is not None: - # Whisper needs the stride data - items = [items, stride] final_items.append(items) - if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps: + + if stride and self.type == "seq2seq": items = _find_longest_common_sequence(final_items, self.tokenizer) - elif stride and self.type == "seq2seq_whisper" and return_timestamps: - items = _find_timestamp_sequence( - final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions + elif self.type == "seq2seq_whisper": + time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions + # Send the chunking back to seconds, it's easier to handle in whisper + sampling_rate = self.feature_extractor.sampling_rate + for output in model_outputs: + if "stride" in output: + chunk_len, stride_left, stride_right = output["stride"] + # Go back in seconds + chunk_len /= sampling_rate + stride_left /= sampling_rate + stride_right /= sampling_rate + output["stride"] = chunk_len, stride_left, stride_right + + text, optional = self.tokenizer._decode_asr( + model_outputs, + return_timestamps=return_timestamps, + return_language=return_language, + time_precision=time_precision, ) else: items = np.concatenate(final_items, axis=1) @@ -631,14 +546,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): offsets = [] for word, (start_offset, end_offset) in chunk_offset: offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset}) - else: + elif self.type != "seq2seq_whisper": skip_special_tokens = self.type != "ctc" text = 
self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens) - if return_timestamps and self.type == "seq2seq_whisper": - offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[ - "offsets" - ] - elif return_timestamps: + if return_timestamps: offsets = self.tokenizer.decode( items, skip_special_tokens=skip_special_tokens, output_char_offsets=True )["char_offsets"] @@ -656,14 +567,119 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)}) optional["chunks"] = chunks - elif return_timestamps and self.type == "seq2seq_whisper": - optional["chunks"] = offsets extra = defaultdict(list) for output in model_outputs: output.pop("tokens", None) output.pop("logits", None) output.pop("is_last", None) + output.pop("stride", None) for k, v in output.items(): extra[k].append(v) return {"text": text, **optional, **extra} + + +def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions): + """ + Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since + `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only + iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is + processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to + properly compute the final `offset`. 
+ """ + # index of the first timestamp token + timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1 + items = [] + # approximation of the token to time ratio : ~0.2seconds + time_precision = feature_extractor.chunk_length / max_source_positions + time = 0 + for seq_idx, item in enumerate(sequences): + sequence, stride = item + if isinstance(sequence, list): + sequence = np.array(sequence) + chunk_len, stride_left, stride_right = stride + sequence = sequence.squeeze(0) + # get rid of the `forced_decoder_idx` that are use to parametrize the generation + begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0 + sequence = sequence[begin_idx:] + + timestamp_tokens = sequence >= timestamp_begin + if seq_idx != 0 and sum(timestamp_tokens) > 0: + consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 + last_timestamp = np.where(timestamp_tokens)[0][-1] + consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive + time -= stride_left + stride_right + offset = int((time / feature_extractor.sampling_rate) / time_precision) + overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision) + # relevant timestamps are in the overlapping part + relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0] + if relevant_timestamp.shape[0] > 0: + relevant_timestamp = ( + consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0] + ) + # if a big stride is used, we need to check some of the previous items for the best overlap + best_match = 0 + sliced_sequence = [] + for idx, previous_sequence in enumerate(reversed(items)): + previous_tokens = previous_sequence[1:-1] + if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0: + break # the previous sequence is too far in the past + if len(previous_tokens) > 0: + # find the longest common sequence between 
the overlapping parts + index_left, index_right, match_length = _fast_find_longest_common_sequence( + sequence[1:relevant_timestamp], previous_tokens + ) + # don't do anything if only 1 token was matched + if match_length > 1 and match_length > best_match: + best_match = match_length + best_idx = idx + end_of_curr_sequence_idx = ( + np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1 + ) + end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left + # if all the tokens are matched, suffix + if index_left == 0 and match_length == len(previous_tokens): + sliced_sequence = np.insert( + sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0] + ) + sliced_sequence[-1] = previous_sequence[-1] + # if part of the previous sequence is not taken + elif index_left >= 0: + sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx] + # let's insert the missing part of the previous sequence + previous_slice = ( + previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]] + ) + sliced_sequence = np.insert(sliced_sequence, 0, previous_slice) + sliced_sequence[-1] += offset + + if len(sliced_sequence) > 0: + items[len(items) - best_idx - 1] = sliced_sequence + items = items[: len(items) - best_idx] + sequence = sequence[end_of_curr_sequence_idx:] + + # sequence might have changed + timestamp_tokens = sequence >= timestamp_begin + consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1 + if sum(timestamp_tokens) > 0: + last_timestamp = np.where(timestamp_tokens)[0][-1] + consecutive = ( + np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive + ) + + if len(consecutive) > 0: + last_slice = 0 + for current_slice in consecutive: + actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0] + sliced_tokens = sequence[last_slice:current_slice] + duration = sliced_tokens[-1] - sliced_tokens[0] + sliced_tokens[0] = 
actual_offset + sliced_tokens[-1] = actual_offset + duration + items.append(sliced_tokens) + last_slice = current_slice + + time += chunk_len + result = [] + for i in range(len(items)): + result += items[i].tolist() + return result diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index acfb577cef..9ceef149fa 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -15,6 +15,7 @@ import unittest from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast +from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence from transformers.testing_utils import slow from ...test_tokenization_common import TokenizerTesterMixin @@ -115,6 +116,84 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase): expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False ) + def test_output_offsets(self): + tokenizer = self.get_tokenizer() + previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612] + self.assertEqual( + tokenizer.decode(previous_sequence, output_offsets=True), + { + "text": " not worth thinking about.", + "offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}], + }, + ) + + # Merge when the previous sequence is a suffix of the next sequence + # fmt: off + next_sequences_1 = [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257] + # fmt: on + self.assertEqual( + tokenizer.decode(next_sequences_1, output_offsets=True), + { + "text": ( + " of spectators, retrievality is not worth thinking about. 
His instant panic was followed by a" + " small, sharp blow high on his chest.<|endoftext|>" + ), + "offsets": [ + {"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)}, + { + "text": " His instant panic was followed by a small, sharp blow high on his chest.", + "timestamp": (5.0, 9.4), + }, + ], + }, + ) + + def test_find_longest_common_subsequence(self): + previous_sequence = [1, 2, 3] + next_sequence = [2, 3, 4, 5] + merge = _find_longest_common_sequence([previous_sequence, next_sequence]) + self.assertEqual(merge, [1, 2, 3, 4, 5]) + + # Now previous is larger than next. + # We merge what we can and remove the extra right side of the left sequence + previous_sequence = [1, 2, 3, 4, 5, 6, 7] + next_sequence = [2, 3, 4, 5] + merge = _find_longest_common_sequence([previous_sequence, next_sequence]) + self.assertEqual(merge, [1, 2, 3, 4, 5]) + + # Nothing in common + previous_sequence = [1, 2, 3] + next_sequence = [4, 5, 6] + merge = _find_longest_common_sequence([previous_sequence, next_sequence]) + self.assertEqual(merge, [1, 2, 3, 4, 5, 6]) + + # Some errors in the overlap. 
+ # We take from previous on the left, from the next on the right of the overlap + previous_sequence = [1, 2, 3, 4, 99] + next_sequence = [2, 98, 4, 5, 6] + merge = _find_longest_common_sequence([previous_sequence, next_sequence]) + self.assertEqual(merge, [1, 2, 3, 4, 5, 6]) + + # We take from previous on the left, from the next on the right of the overlap + previous_sequence = [1, 2, 99, 4, 5] + next_sequence = [2, 3, 4, 98, 6] + merge = _find_longest_common_sequence([previous_sequence, next_sequence]) + self.assertEqual(merge, [1, 2, 99, 4, 98, 6]) + + # This works on 3 sequences + seq1 = [1, 2, 3] + seq2 = [2, 3, 4] + seq3 = [3, 4, 5] + merge = _find_longest_common_sequence([seq1, seq2, seq3]) + self.assertEqual(merge, [1, 2, 3, 4, 5]) + + # This works on 3 sequences with errors + seq1 = [1, 2, 3, 98, 5] + seq2 = [2, 99, 4, 5, 6, 7] + seq3 = [4, 97, 6, 7, 8] + merge = _find_longest_common_sequence([seq1, seq2, seq3]) + self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8]) + class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase): checkpoint_name = "openai/whisper-small.en" diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index fcabc0ad35..d266438ac3 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -538,7 +538,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): "tight-loan cloth that was the only garment he wore, the " "cut" ), - "timestamp": (5.5, 11.94), + "timestamp": (5.5, 11.95), }, { "text": ( @@ -546,15 +546,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): "overstrained eyes, even the soaring arena around him " "with" ), - "timestamp": (11.94, 19.6), + "timestamp": (11.95, 19.61), }, { "text": " the thousands of spectators, retrievality is not worth thinking about.", - "timestamp": (19.6, 26.66), + "timestamp": (19.61, 
25.0), }, { "text": " His instant panic was followed by a small, sharp blow high on his chest.", - "timestamp": (26.66, 31.06), + "timestamp": (25.0, 29.4), }, ], "text": (