Refactor whisper asr pipeline to include language too. (#21427)
* [WIP] whisper refacto to support language output. * Handling merges. * A bit more cleanup and comments. * Many improvements. Lots of details everywhere. * Cleanup old code and tests. * Handle lone timestamp tokens (just recover when something bad happens). * Adding return_language example. * No ffmpeg. * Hmm. * Some corrections. * Both fast and slow. * New black. * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove print. * Undoing tests modifications. * Smaller test modifications. * Rename. * Remove maxDiff. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -703,3 +703,289 @@ class WhisperTokenizer(PreTrainedTokenizer):
|
|||||||
forced_tokens = self.prefix_tokens[1:]
|
forced_tokens = self.prefix_tokens[1:]
|
||||||
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
|
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
|
||||||
return forced_decoder_ids
|
return forced_decoder_ids
|
||||||
|
|
||||||
|
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
    """
    Decode the ASR pipeline's `model_outputs` with this tokenizer.

    Thin wrapper: forwards every argument unchanged to the module-level `_decode_asr` function, passing `self` as
    the tokenizer, and returns whatever that function returns.
    """
    decode_options = {
        "return_timestamps": return_timestamps,
        "return_language": return_language,
        "time_precision": time_precision,
    }
    return _decode_asr(self, model_outputs, **decode_options)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision):
    """
    Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
    the various options not allowed in other seq2seq models

    Args:
        tokenizer:
            Object providing `convert_tokens_to_ids`, `all_special_ids` and `decode`.
        model_outputs:
            Iterable of dicts. `output["tokens"][0]` must expose `.tolist()` and yields the token ids of one audio
            chunk. An optional `output["stride"]` holds `(chunk_len, stride_left, stride_right)`; these values are
            added to / subtracted from the emitted timestamps, so they must be in the same time unit as
            `time_precision`.
        return_timestamps:
            When truthy, chunks are split on timestamp tokens and keep a `"timestamp"` tuple.
        return_language:
            When truthy, chunks keep the `"language"` key.
        time_precision:
            Amount of time represented by one step between two consecutive timestamp token ids.

    Returns:
        A `(full_text, optional)` tuple. `full_text` is the concatenation of all chunk texts. `optional` is
        `{"chunks": chunks}` when `return_timestamps` or `return_language` is truthy, `{}` otherwise.

    Raises:
        ValueError: if `return_timestamps` is truthy and text tokens are left over after the last timestamp token.
    """

    # =========== Overview ============
    # - iterate over all outputs
    # - all tokens within output
    # - Each token can be
    #   - language token
    #   - special token
    #   - timestamp token
    #   - text token
    # - We accumulate the text tokens.
    # - We split on end timestamps
    # - Lots of complexity comes from stride and timestamps

    last_language = None

    def new_chunk():
        # Reads `last_language` at call time, so a new chunk inherits the most recently seen language.
        return {"language": last_language, "timestamp": [None, None], "text": ""}

    # Welcome to the state machine !
    chunks = []
    chunk = new_chunk()
    time_offset = 0.0
    # Every id strictly greater than `<|notimestamps|>` is treated as a timestamp token.
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    # Token lists (one per model output) not yet flushed into a chunk; merged on flush.
    previous_tokens = []
    skip = False
    right_stride_start = None

    all_special_ids = set(tokenizer.all_special_ids)
    # - iterate over all outputs
    for output in model_outputs:
        # We can drop everything to Python list, it's going to make
        # our lives easier
        token_ids = output["tokens"][0].tolist()

        # Those keep track of timestamps within strides
        # Which need to be skipped and resolve all tokens in a single
        # chunk.
        last_timestamp = None
        first_timestamp = timestamp_begin

        if "stride" in output:
            chunk_len, stride_left, stride_right = output["stride"]
            # Offset the timings to account for the other `model_outputs`.
            time_offset -= stride_left
            right_stride_start = chunk_len - stride_right

            # Keeping track of timestamps within strides
            # We're going to NOT split on those, and delay until we're
            # out of BOTH stride. Otherwise lots of issues occur and
            # corner cases
            if stride_left:
                first_timestamp = stride_left / time_precision + timestamp_begin
            if stride_right:
                for token in reversed(token_ids):
                    if token >= timestamp_begin:
                        # There can be several token in the right stride
                        # But the last one is ALWAYS going to be skipped
                        if (
                            last_timestamp is not None
                            and (token - timestamp_begin) * time_precision < right_stride_start
                        ):
                            break
                        last_timestamp = token

        current_tokens = []

        # - all tokens within output
        for token in token_ids:
            # 4 possible states for each token
            # - 1/ Language code
            # - 2/ all other special tokens (which we ignore)
            # - 3/ Timestamp
            # - 4/ Regular text
            if token in all_special_ids:
                # Either language code or other
                text = tokenizer.decode([token])
                # Removing outer shell <|XX|>
                text = text[2:-2]
                language = LANGUAGES.get(text, None)
                # 2/ When `language` is None this is a regular special token and it is ignored.
                if language is not None:
                    # 1/ Indeed some language
                    # TODO Handle when language is different from the previous
                    # one, and we cannot use timestamped tokens to create chunks
                    if last_language and language != last_language and not return_timestamps:
                        previous_tokens.append(current_tokens)
                        resolved_tokens = _find_longest_common_sequence(previous_tokens)
                        resolved_text = tokenizer.decode(resolved_tokens)
                        chunk["text"] = resolved_text
                        chunks.append(chunk)

                        # Flush all our temporary context
                        previous_tokens = []
                        current_tokens = []
                        chunk = new_chunk()
                    chunk["language"] = language
                    last_language = language
            elif token >= timestamp_begin:
                # 3/ Timestamp token
                time = (token - timestamp_begin) * time_precision + time_offset
                time = round(time, 2)
                if last_timestamp and token >= last_timestamp:
                    # Whisper outputted a timestamp token, but it falls within
                    # our stride, so we're going to skip it for the time being
                    # and resolve this later
                    # Skip is necessary because timestamp tokens always come
                    # by pair, so we need to skip the next one too (which would mark the start of another chunk).
                    skip = True
                elif skip or (previous_tokens and token < first_timestamp):
                    skip = False
                elif chunk["timestamp"][0] is None:
                    chunk["timestamp"][0] = time
                elif time != chunk["timestamp"][0]:
                    # This is the end of the timestamp chunk.
                    # When `time` equals the chunk start instead, the model repeated the
                    # start timestamp (an issue in the underlying model output): that duplicate
                    # is deliberately ignored so the original token stays the de facto start.
                    chunk["timestamp"][1] = time
                    # Handling merges.
                    previous_tokens.append(current_tokens)
                    resolved_tokens = _find_longest_common_sequence(previous_tokens)
                    resolved_text = tokenizer.decode(resolved_tokens)
                    chunk["text"] = resolved_text
                    chunks.append(chunk)

                    # Flush all our temporary context
                    previous_tokens = []
                    current_tokens = []
                    chunk = new_chunk()
            else:
                # 4/ Regular token
                # We just append to the list of all tokens so we can handle
                # merges later and decode into text.
                current_tokens.append(token)

        if "stride" in output:
            time_offset += chunk_len - stride_right

        # Leftover tokens
        if current_tokens:
            previous_tokens.append(current_tokens)
        elif not any(previous_tokens):
            # Nothing but empty token lists are pending: start over from a clean chunk.
            chunk = new_chunk()
            previous_tokens = []
            current_tokens = []

    if previous_tokens:
        if return_timestamps:
            # Last token should always be timestamps, so there shouldn't be
            # leftover
            raise ValueError(
                "There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
                " WhisperTimeStampLogitsProcessor used?"
            )
        # Happens when we don't use timestamps
        resolved_tokens = _find_longest_common_sequence(previous_tokens)
        resolved_text = tokenizer.decode(resolved_tokens)
        chunk["text"] = resolved_text
        chunks.append(chunk)

    # Preparing and cleaning up the pipeline output
    full_text = "".join(chunk["text"] for chunk in chunks)
    if return_timestamps or return_language:
        for chunk in chunks:
            if not return_timestamps:
                chunk.pop("timestamp")
            else:
                chunk["timestamp"] = tuple(chunk["timestamp"])
            if not return_language:
                chunk.pop("language")
        optional = {"chunks": chunks}
    else:
        optional = {}
    return full_text, optional
|
||||||
|
|
||||||
|
|
||||||
|
def _find_longest_common_sequence(sequences):
|
||||||
|
# It would be much harder to do O(n) because of fault tolerance.
|
||||||
|
# We actually have a really good property which is that the total sequence
|
||||||
|
# MUST be those subsequences in order.
|
||||||
|
left_sequence = sequences[0]
|
||||||
|
left_length = len(left_sequence)
|
||||||
|
total_sequence = []
|
||||||
|
for right_sequence in sequences[1:]:
|
||||||
|
# index = 0
|
||||||
|
max_ = 0.0
|
||||||
|
max_indices = (left_length, left_length, 0, 0)
|
||||||
|
# Here we're sliding matches
|
||||||
|
# [a, b, c, d]
|
||||||
|
# [c, d, f]
|
||||||
|
# = [c] == [d]
|
||||||
|
#
|
||||||
|
# [a, b, c, d]
|
||||||
|
# [c, d, f]
|
||||||
|
# = [c, d] == [c, d]
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# [a, b, c, d]
|
||||||
|
# [c, d, f]
|
||||||
|
#
|
||||||
|
# = [b, c, d] == [c, d, f]
|
||||||
|
#
|
||||||
|
# [a, b, c, d]
|
||||||
|
# [c, d, f]
|
||||||
|
#
|
||||||
|
# [a, b, c] == [c, d, f]
|
||||||
|
#
|
||||||
|
# [a, b, c, d]
|
||||||
|
# [d, f]
|
||||||
|
#
|
||||||
|
# [a, b] == [d, f]
|
||||||
|
#
|
||||||
|
# [a, b, c, d]
|
||||||
|
# [f]
|
||||||
|
#
|
||||||
|
# [a] == [f]
|
||||||
|
right_length = len(right_sequence)
|
||||||
|
for i in range(1, left_length + right_length):
|
||||||
|
# epsilon to favor long perfect matches
|
||||||
|
eps = i / 10000.0
|
||||||
|
|
||||||
|
# Slightly convoluted because we don't want out of bound indices
|
||||||
|
# This will be necessary for a small conflict resolution optimization
|
||||||
|
# later
|
||||||
|
left_start = max(0, left_length - i)
|
||||||
|
left_stop = min(left_length, left_length + right_length - i)
|
||||||
|
left = np.array(left_sequence[left_start:left_stop])
|
||||||
|
|
||||||
|
right_start = max(0, i - left_length)
|
||||||
|
right_stop = min(right_length, i)
|
||||||
|
right = np.array(right_sequence[right_start:right_stop])
|
||||||
|
|
||||||
|
# We can only match subsequences of the same size.
|
||||||
|
if len(left) != len(right):
|
||||||
|
raise RuntimeError(
|
||||||
|
"There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
matches = np.sum(left == right)
|
||||||
|
matching = matches / i + eps
|
||||||
|
if matches > 1 and matching > max_:
|
||||||
|
max_ = matching
|
||||||
|
max_indices = (left_start, left_stop, right_start, right_stop)
|
||||||
|
|
||||||
|
(left_start, left_stop, right_start, right_stop) = max_indices
|
||||||
|
|
||||||
|
# This is a small conflict optimization since those sequences overlap
|
||||||
|
# in audio.
|
||||||
|
# We're going to give more confidence to the left sequence
|
||||||
|
# for the left of the overlap,
|
||||||
|
# and to the right of the sequence, for the right of the overlap
|
||||||
|
left_mid = (left_stop + left_start) // 2
|
||||||
|
right_mid = (right_stop + right_start) // 2
|
||||||
|
total_sequence.extend(left_sequence[:left_mid])
|
||||||
|
left_sequence = right_sequence[right_mid:]
|
||||||
|
left_length = len(left_sequence)
|
||||||
|
|
||||||
|
total_sequence.extend(left_sequence)
|
||||||
|
|
||||||
|
return total_sequence
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ...tokenization_utils_base import BatchEncoding
|
|||||||
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
from ...tokenization_utils_fast import PreTrainedTokenizerFast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .english_normalizer import EnglishTextNormalizer
|
from .english_normalizer import EnglishTextNormalizer
|
||||||
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer
|
from .tokenization_whisper import LANGUAGES, TASK_IDS, TO_LANGUAGE_CODE, WhisperTokenizer, _decode_asr
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -475,3 +475,12 @@ class WhisperTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
forced_tokens = self.prefix_tokens[1:]
|
forced_tokens = self.prefix_tokens[1:]
|
||||||
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
|
forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_tokens)]
|
||||||
return forced_decoder_ids
|
return forced_decoder_ids
|
||||||
|
|
||||||
|
def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
    """
    Decode the ASR pipeline's `model_outputs` with this fast tokenizer.

    Delegates to the `_decode_asr` function imported at module level, handing over `self` as the tokenizer and
    forwarding the keyword-only options untouched.
    """
    forwarded = {
        "return_timestamps": return_timestamps,
        "return_language": return_language,
        "time_precision": time_precision,
    }
    return _decode_asr(self, model_outputs, **forwarded)
|
||||||
|
|||||||
@@ -82,112 +82,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
|
|||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
|
|
||||||
"""
|
|
||||||
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
|
|
||||||
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
|
|
||||||
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
|
|
||||||
processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
|
|
||||||
properly compute the final `offset`.
|
|
||||||
"""
|
|
||||||
# index of the first timestamp token
|
|
||||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
|
||||||
items = []
|
|
||||||
# approximation of the token to time ratio : ~0.2seconds
|
|
||||||
time_precision = feature_extractor.chunk_length / max_source_positions
|
|
||||||
time = 0
|
|
||||||
for seq_idx, item in enumerate(sequences):
|
|
||||||
sequence, stride = item
|
|
||||||
if isinstance(sequence, list):
|
|
||||||
sequence = np.array(sequence)
|
|
||||||
chunk_len, stride_left, stride_right = stride
|
|
||||||
sequence = sequence.squeeze(0)
|
|
||||||
# get rid of the `forced_decoder_idx` that are use to parametrize the generation
|
|
||||||
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
|
|
||||||
sequence = sequence[begin_idx:]
|
|
||||||
|
|
||||||
timestamp_tokens = sequence >= timestamp_begin
|
|
||||||
if seq_idx != 0 and sum(timestamp_tokens) > 0:
|
|
||||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
|
||||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
|
||||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
|
||||||
time -= stride_left + stride_right
|
|
||||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
|
||||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
|
||||||
# relevant timestamps are in the overlapping part
|
|
||||||
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
|
|
||||||
if relevant_timestamp.shape[0] > 0:
|
|
||||||
relevant_timestamp = (
|
|
||||||
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
|
|
||||||
)
|
|
||||||
# if a big stride is used, we need to check some of the previous items for the best overlap
|
|
||||||
best_match = 0
|
|
||||||
sliced_sequence = []
|
|
||||||
for idx, previous_sequence in enumerate(reversed(items)):
|
|
||||||
previous_tokens = previous_sequence[1:-1]
|
|
||||||
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
|
|
||||||
break # the previous sequence is too far in the past
|
|
||||||
if len(previous_tokens) > 0:
|
|
||||||
# find the longest common sequence between the overlapping parts
|
|
||||||
index_left, index_right, match_length = _fast_find_longest_common_sequence(
|
|
||||||
sequence[1:relevant_timestamp], previous_tokens
|
|
||||||
)
|
|
||||||
# don't do anything if only 1 token was matched
|
|
||||||
if match_length > 1 and match_length > best_match:
|
|
||||||
best_match = match_length
|
|
||||||
best_idx = idx
|
|
||||||
end_of_curr_sequence_idx = (
|
|
||||||
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
|
|
||||||
)
|
|
||||||
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
|
|
||||||
# if all the tokens are matched, suffix
|
|
||||||
if index_left == 0 and match_length == len(previous_tokens):
|
|
||||||
sliced_sequence = np.insert(
|
|
||||||
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
|
|
||||||
)
|
|
||||||
sliced_sequence[-1] = previous_sequence[-1]
|
|
||||||
# if part of the previous sequence is not taken
|
|
||||||
elif index_left >= 0:
|
|
||||||
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
|
|
||||||
# let's insert the missing part of the previous sequence
|
|
||||||
previous_slice = (
|
|
||||||
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
|
|
||||||
)
|
|
||||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
|
|
||||||
sliced_sequence[-1] += offset
|
|
||||||
|
|
||||||
if len(sliced_sequence) > 0:
|
|
||||||
items[len(items) - best_idx - 1] = sliced_sequence
|
|
||||||
items = items[: len(items) - best_idx]
|
|
||||||
sequence = sequence[end_of_curr_sequence_idx:]
|
|
||||||
|
|
||||||
# sequence might have changed
|
|
||||||
timestamp_tokens = sequence >= timestamp_begin
|
|
||||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
|
||||||
if sum(timestamp_tokens) > 0:
|
|
||||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
|
||||||
consecutive = (
|
|
||||||
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(consecutive) > 0:
|
|
||||||
last_slice = 0
|
|
||||||
for current_slice in consecutive:
|
|
||||||
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
|
|
||||||
sliced_tokens = sequence[last_slice:current_slice]
|
|
||||||
duration = sliced_tokens[-1] - sliced_tokens[0]
|
|
||||||
sliced_tokens[0] = actual_offset
|
|
||||||
sliced_tokens[-1] = actual_offset + duration
|
|
||||||
items.append(sliced_tokens)
|
|
||||||
last_slice = current_slice
|
|
||||||
|
|
||||||
time += chunk_len
|
|
||||||
result = []
|
|
||||||
for i in range(len(items)):
|
|
||||||
result += items[i].tolist()
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
|
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
|
||||||
seq_len_left = len(sequence_left)
|
seq_len_left = len(sequence_left)
|
||||||
seq_len_right = len(sequence_right)
|
seq_len_right = len(sequence_right)
|
||||||
@@ -384,6 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
ignore_warning=None,
|
ignore_warning=None,
|
||||||
decoder_kwargs=None,
|
decoder_kwargs=None,
|
||||||
return_timestamps=None,
|
return_timestamps=None,
|
||||||
|
return_language=None,
|
||||||
generate_kwargs=None,
|
generate_kwargs=None,
|
||||||
max_new_tokens=None,
|
max_new_tokens=None,
|
||||||
):
|
):
|
||||||
@@ -413,6 +308,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if return_timestamps is not None:
|
if return_timestamps is not None:
|
||||||
forward_params["return_timestamps"] = return_timestamps
|
forward_params["return_timestamps"] = return_timestamps
|
||||||
postprocess_params["return_timestamps"] = return_timestamps
|
postprocess_params["return_timestamps"] = return_timestamps
|
||||||
|
if return_language is not None:
|
||||||
|
postprocess_params["return_language"] = return_language
|
||||||
|
|
||||||
return preprocess_params, forward_params, postprocess_params
|
return preprocess_params, forward_params, postprocess_params
|
||||||
|
|
||||||
@@ -580,7 +477,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
extra = model_inputs
|
extra = model_inputs
|
||||||
return {"is_last": is_last, **out, **extra}
|
return {"is_last": is_last, **out, **extra}
|
||||||
|
|
||||||
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None):
|
def postprocess(
|
||||||
|
self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
|
||||||
|
):
|
||||||
# Optional return types
|
# Optional return types
|
||||||
optional = {}
|
optional = {}
|
||||||
|
|
||||||
@@ -591,12 +490,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
|
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
|
||||||
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
|
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
|
||||||
|
|
||||||
|
if return_language is not None and self.type != "seq2seq_whisper":
|
||||||
|
raise ValueError("Only whisper can return language for now.")
|
||||||
|
|
||||||
final_items = []
|
final_items = []
|
||||||
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
||||||
stride = None
|
stride = None
|
||||||
for outputs in model_outputs:
|
for outputs in model_outputs:
|
||||||
items = outputs[key].numpy()
|
items = outputs[key].numpy()
|
||||||
stride = outputs.pop("stride", None)
|
stride = outputs.get("stride", None)
|
||||||
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
|
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
|
||||||
total_n, left, right = stride
|
total_n, left, right = stride
|
||||||
# Total_n might be < logits.shape[1]
|
# Total_n might be < logits.shape[1]
|
||||||
@@ -605,15 +507,28 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
# This won't work with left padding (which doesn't exist right now)
|
# This won't work with left padding (which doesn't exist right now)
|
||||||
right_n = total_n - right
|
right_n = total_n - right
|
||||||
items = items[:, left:right_n]
|
items = items[:, left:right_n]
|
||||||
if self.type == "seq2seq_whisper" and return_timestamps and stride is not None:
|
|
||||||
# Whisper needs the stride data
|
|
||||||
items = [items, stride]
|
|
||||||
final_items.append(items)
|
final_items.append(items)
|
||||||
if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps:
|
|
||||||
|
if stride and self.type == "seq2seq":
|
||||||
items = _find_longest_common_sequence(final_items, self.tokenizer)
|
items = _find_longest_common_sequence(final_items, self.tokenizer)
|
||||||
elif stride and self.type == "seq2seq_whisper" and return_timestamps:
|
elif self.type == "seq2seq_whisper":
|
||||||
items = _find_timestamp_sequence(
|
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
|
||||||
final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions
|
# Send the chunking back to seconds, it's easier to handle in whisper
|
||||||
|
sampling_rate = self.feature_extractor.sampling_rate
|
||||||
|
for output in model_outputs:
|
||||||
|
if "stride" in output:
|
||||||
|
chunk_len, stride_left, stride_right = output["stride"]
|
||||||
|
# Go back in seconds
|
||||||
|
chunk_len /= sampling_rate
|
||||||
|
stride_left /= sampling_rate
|
||||||
|
stride_right /= sampling_rate
|
||||||
|
output["stride"] = chunk_len, stride_left, stride_right
|
||||||
|
|
||||||
|
text, optional = self.tokenizer._decode_asr(
|
||||||
|
model_outputs,
|
||||||
|
return_timestamps=return_timestamps,
|
||||||
|
return_language=return_language,
|
||||||
|
time_precision=time_precision,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
items = np.concatenate(final_items, axis=1)
|
items = np.concatenate(final_items, axis=1)
|
||||||
@@ -631,14 +546,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
offsets = []
|
offsets = []
|
||||||
for word, (start_offset, end_offset) in chunk_offset:
|
for word, (start_offset, end_offset) in chunk_offset:
|
||||||
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
||||||
else:
|
elif self.type != "seq2seq_whisper":
|
||||||
skip_special_tokens = self.type != "ctc"
|
skip_special_tokens = self.type != "ctc"
|
||||||
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
||||||
if return_timestamps and self.type == "seq2seq_whisper":
|
if return_timestamps:
|
||||||
offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[
|
|
||||||
"offsets"
|
|
||||||
]
|
|
||||||
elif return_timestamps:
|
|
||||||
offsets = self.tokenizer.decode(
|
offsets = self.tokenizer.decode(
|
||||||
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
|
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
|
||||||
)["char_offsets"]
|
)["char_offsets"]
|
||||||
@@ -656,14 +567,119 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
|
|
||||||
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
|
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
|
||||||
optional["chunks"] = chunks
|
optional["chunks"] = chunks
|
||||||
elif return_timestamps and self.type == "seq2seq_whisper":
|
|
||||||
optional["chunks"] = offsets
|
|
||||||
|
|
||||||
extra = defaultdict(list)
|
extra = defaultdict(list)
|
||||||
for output in model_outputs:
|
for output in model_outputs:
|
||||||
output.pop("tokens", None)
|
output.pop("tokens", None)
|
||||||
output.pop("logits", None)
|
output.pop("logits", None)
|
||||||
output.pop("is_last", None)
|
output.pop("is_last", None)
|
||||||
|
output.pop("stride", None)
|
||||||
for k, v in output.items():
|
for k, v in output.items():
|
||||||
extra[k].append(v)
|
extra[k].append(v)
|
||||||
return {"text": text, **optional, **extra}
|
return {"text": text, **optional, **extra}
|
||||||
|
|
||||||
|
|
||||||
|
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
    """
    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
    processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
    properly compute the final `offset`.

    Args:
        sequences:
            Iterable of `(sequence, stride)` pairs. `sequence` is a list or array of token ids with a leading axis
            that is squeezed away; `stride` is `(chunk_len, stride_left, stride_right)`, each divided by
            `feature_extractor.sampling_rate` here (presumably sample counts - confirm against the caller).
        tokenizer: Used only to look up the id of `<|notimestamps|>`.
        feature_extractor: Provides `chunk_length` and `sampling_rate`.
        max_source_positions: Divides `chunk_length` to obtain the duration of one timestamp step.

    Returns:
        A flat Python list of token ids with the timestamp tokens re-offset across chunks.
    """
    # index of the first timestamp token
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    items = []
    # approximation of the token to time ratio : ~0.2seconds
    time_precision = feature_extractor.chunk_length / max_source_positions
    time = 0
    for seq_idx, item in enumerate(sequences):
        sequence, stride = item
        if isinstance(sequence, list):
            sequence = np.array(sequence)
        chunk_len, stride_left, stride_right = stride
        sequence = sequence.squeeze(0)
        # get rid of the `forced_decoder_idx` that are use to parametrize the generation
        begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
        sequence = sequence[begin_idx:]

        timestamp_tokens = sequence >= timestamp_begin
        if seq_idx != 0 and timestamp_tokens.any():
            # Positions where a timestamp token directly follows another timestamp token.
            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
            time -= stride_left + stride_right
            offset = int((time / feature_extractor.sampling_rate) / time_precision)
            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
            # relevant timestamps are in the overlapping part
            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
            if relevant_timestamp.shape[0] > 0:
                relevant_timestamp = (
                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
                )
                # if a big stride is used, we need to check some of the previous items for the best overlap
                best_match = 0
                sliced_sequence = []
                for idx, previous_sequence in enumerate(reversed(items)):
                    previous_tokens = previous_sequence[1:-1]
                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
                        break  # the previous sequence is too far in the past
                    if len(previous_tokens) > 0:
                        # find the longest common sequence between the overlapping parts
                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
                            sequence[1:relevant_timestamp], previous_tokens
                        )
                        # don't do anything if only 1 token was matched
                        if match_length > 1 and match_length > best_match:
                            best_match = match_length
                            best_idx = idx
                            end_of_curr_sequence_idx = (
                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                            )
                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                            # if all the tokens are matched, suffix
                            if index_left == 0 and match_length == len(previous_tokens):
                                sliced_sequence = np.insert(
                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
                                )
                                sliced_sequence[-1] = previous_sequence[-1]
                            # if part of the previous sequence is not taken
                            elif index_left >= 0:
                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                # let's insert the missing part of the previous sequence
                                previous_slice = (
                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
                                )
                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                sliced_sequence[-1] += offset

                # `best_idx` and `end_of_curr_sequence_idx` are only bound when a match was found,
                # which is exactly when `sliced_sequence` is non-empty.
                if len(sliced_sequence) > 0:
                    items[len(items) - best_idx - 1] = sliced_sequence
                    items = items[: len(items) - best_idx]
                    sequence = sequence[end_of_curr_sequence_idx:]

        # sequence might have changed
        timestamp_tokens = sequence >= timestamp_begin
        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
        if timestamp_tokens.any():
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = (
                np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
            )

        if len(consecutive) > 0:
            last_slice = 0
            for current_slice in consecutive:
                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                sliced_tokens = sequence[last_slice:current_slice]
                duration = sliced_tokens[-1] - sliced_tokens[0]
                # Re-anchor the segment on the end of the previous one while keeping its duration.
                sliced_tokens[0] = actual_offset
                sliced_tokens[-1] = actual_offset + duration
                items.append(sliced_tokens)
                last_slice = current_slice

        time += chunk_len
    # Flatten the per-segment arrays into one plain list of token ids.
    result = []
    for segment in items:
        result.extend(segment.tolist())
    return result
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||||
|
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
from ...test_tokenization_common import TokenizerTesterMixin
|
from ...test_tokenization_common import TokenizerTesterMixin
|
||||||
@@ -115,6 +116,84 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
|
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_output_offsets(self):
    """Check the structure returned by `decode(..., output_offsets=True)`.

    Two token sequences containing timestamp tokens are decoded. Each result
    must be a dict holding the full `"text"` and an `"offsets"` list with one
    `{"text", "timestamp"}` entry per timestamp-delimited segment.
    """
    whisper_tokenizer = self.get_tokenizer()

    # A single segment enclosed by a start and an end timestamp token.
    single_segment_ids = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
    single_segment_text = " not worth thinking about."
    expected_single = {
        "text": single_segment_text,
        "offsets": [{"text": single_segment_text, "timestamp": (22.56, 24.96)}],
    }
    decoded_single = whisper_tokenizer.decode(single_segment_ids, output_offsets=True)
    self.assertEqual(decoded_single, expected_single)

    # A longer sequence whose first segment ends with the same text tokens as
    # the sequence above. It holds two timestamped segments and a trailing
    # token that shows up as "<|endoftext|>" in the full text only.
    # fmt: off
    two_segment_ids = [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
    # fmt: on
    first_segment = {
        "text": " of spectators, retrievality is not worth thinking about.",
        "timestamp": (0.0, 5.0),
    }
    second_segment = {
        "text": " His instant panic was followed by a small, sharp blow high on his chest.",
        "timestamp": (5.0, 9.4),
    }
    expected_double = {
        "text": (
            " of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
            " small, sharp blow high on his chest.<|endoftext|>"
        ),
        "offsets": [first_segment, second_segment],
    }
    decoded_double = whisper_tokenizer.decode(two_segment_ids, output_offsets=True)
    self.assertEqual(decoded_double, expected_double)
|
||||||
|
|
||||||
|
def test_find_longest_common_subsequence(self):
    """Exercise `_find_longest_common_sequence` on overlapping token chunks.

    Every case pairs a list of consecutive chunks with the merged sequence
    that is expected back. Cases are run in order and the first mismatch
    fails the test.
    """
    merge_cases = [
        # The end of the first chunk overlaps the start of the second one.
        ([[1, 2, 3], [2, 3, 4, 5]], [1, 2, 3, 4, 5]),
        # The previous chunk is longer than the next one: merge what can be
        # merged and drop the extra right side of the left chunk.
        ([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5]], [1, 2, 3, 4, 5]),
        # Nothing in common: the chunks are simply concatenated.
        ([[1, 2, 3], [4, 5, 6]], [1, 2, 3, 4, 5, 6]),
        # Some errors inside the overlap: tokens come from the previous chunk
        # on the left of the overlap and from the next chunk on its right.
        ([[1, 2, 3, 4, 99], [2, 98, 4, 5, 6]], [1, 2, 3, 4, 5, 6]),
        # Same rule, with the mismatches located elsewhere in the overlap.
        ([[1, 2, 99, 4, 5], [2, 3, 4, 98, 6]], [1, 2, 99, 4, 98, 6]),
        # Three chunks are merged as well.
        ([[1, 2, 3], [2, 3, 4], [3, 4, 5]], [1, 2, 3, 4, 5]),
        # Three chunks with errors in the overlaps.
        ([[1, 2, 3, 98, 5], [2, 99, 4, 5, 6, 7], [4, 97, 6, 7, 8]], [1, 2, 3, 4, 5, 6, 7, 8]),
    ]

    for chunks, expected_merge in merge_cases:
        merged = _find_longest_common_sequence(chunks)
        self.assertEqual(merged, expected_merge)
|
||||||
|
|
||||||
|
|
||||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||||
checkpoint_name = "openai/whisper-small.en"
|
checkpoint_name = "openai/whisper-small.en"
|
||||||
|
|||||||
@@ -538,7 +538,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
"tight-loan cloth that was the only garment he wore, the "
|
"tight-loan cloth that was the only garment he wore, the "
|
||||||
"cut"
|
"cut"
|
||||||
),
|
),
|
||||||
"timestamp": (5.5, 11.94),
|
"timestamp": (5.5, 11.95),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"text": (
|
"text": (
|
||||||
@@ -546,15 +546,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
"overstrained eyes, even the soaring arena around him "
|
"overstrained eyes, even the soaring arena around him "
|
||||||
"with"
|
"with"
|
||||||
),
|
),
|
||||||
"timestamp": (11.94, 19.6),
|
"timestamp": (11.95, 19.61),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"text": " the thousands of spectators, retrievality is not worth thinking about.",
|
"text": " the thousands of spectators, retrievality is not worth thinking about.",
|
||||||
"timestamp": (19.6, 26.66),
|
"timestamp": (19.61, 25.0),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||||
"timestamp": (26.66, 31.06),
|
"timestamp": (25.0, 29.4),
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"text": (
|
"text": (
|
||||||
|
|||||||
Reference in New Issue
Block a user