Refactor whisper asr pipeline to include language too. (#21427)
* [WIP] whisper refacto to support language output. * Handling merges. * A bit more cleanup and comments. * Many improvements. Lots of details everywhere. * Cleanup old code and tests. * Handle lone timestamp tokens (just recover when something bad happens). * Adding return_language example. * No ffmpeg. * Hmm. * Some corrections. * Both fast and slow. * New black. * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove print. * Undoing tests modifications. * Smaller test modifications. * Rename. * Remove maxDiff. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -82,112 +82,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
|
||||
break
|
||||
|
||||
|
||||
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
|
||||
"""
|
||||
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
|
||||
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
|
||||
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
|
||||
processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
|
||||
properly compute the final `offset`.
|
||||
"""
|
||||
# index of the first timestamp token
|
||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
||||
items = []
|
||||
# approximation of the token to time ratio : ~0.2seconds
|
||||
time_precision = feature_extractor.chunk_length / max_source_positions
|
||||
time = 0
|
||||
for seq_idx, item in enumerate(sequences):
|
||||
sequence, stride = item
|
||||
if isinstance(sequence, list):
|
||||
sequence = np.array(sequence)
|
||||
chunk_len, stride_left, stride_right = stride
|
||||
sequence = sequence.squeeze(0)
|
||||
# get rid of the `forced_decoder_idx` that are use to parametrize the generation
|
||||
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
|
||||
sequence = sequence[begin_idx:]
|
||||
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
if seq_idx != 0 and sum(timestamp_tokens) > 0:
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
||||
time -= stride_left + stride_right
|
||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
||||
# relevant timestamps are in the overlapping part
|
||||
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
|
||||
if relevant_timestamp.shape[0] > 0:
|
||||
relevant_timestamp = (
|
||||
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
|
||||
)
|
||||
# if a big stride is used, we need to check some of the previous items for the best overlap
|
||||
best_match = 0
|
||||
sliced_sequence = []
|
||||
for idx, previous_sequence in enumerate(reversed(items)):
|
||||
previous_tokens = previous_sequence[1:-1]
|
||||
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
|
||||
break # the previous sequence is too far in the past
|
||||
if len(previous_tokens) > 0:
|
||||
# find the longest common sequence between the overlapping parts
|
||||
index_left, index_right, match_length = _fast_find_longest_common_sequence(
|
||||
sequence[1:relevant_timestamp], previous_tokens
|
||||
)
|
||||
# don't do anything if only 1 token was matched
|
||||
if match_length > 1 and match_length > best_match:
|
||||
best_match = match_length
|
||||
best_idx = idx
|
||||
end_of_curr_sequence_idx = (
|
||||
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
|
||||
)
|
||||
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
|
||||
# if all the tokens are matched, suffix
|
||||
if index_left == 0 and match_length == len(previous_tokens):
|
||||
sliced_sequence = np.insert(
|
||||
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
|
||||
)
|
||||
sliced_sequence[-1] = previous_sequence[-1]
|
||||
# if part of the previous sequence is not taken
|
||||
elif index_left >= 0:
|
||||
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
|
||||
# let's insert the missing part of the previous sequence
|
||||
previous_slice = (
|
||||
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
|
||||
)
|
||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
|
||||
sliced_sequence[-1] += offset
|
||||
|
||||
if len(sliced_sequence) > 0:
|
||||
items[len(items) - best_idx - 1] = sliced_sequence
|
||||
items = items[: len(items) - best_idx]
|
||||
sequence = sequence[end_of_curr_sequence_idx:]
|
||||
|
||||
# sequence might have changed
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
if sum(timestamp_tokens) > 0:
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = (
|
||||
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
|
||||
)
|
||||
|
||||
if len(consecutive) > 0:
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
|
||||
sliced_tokens = sequence[last_slice:current_slice]
|
||||
duration = sliced_tokens[-1] - sliced_tokens[0]
|
||||
sliced_tokens[0] = actual_offset
|
||||
sliced_tokens[-1] = actual_offset + duration
|
||||
items.append(sliced_tokens)
|
||||
last_slice = current_slice
|
||||
|
||||
time += chunk_len
|
||||
result = []
|
||||
for i in range(len(items)):
|
||||
result += items[i].tolist()
|
||||
return result
|
||||
|
||||
|
||||
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
|
||||
seq_len_left = len(sequence_left)
|
||||
seq_len_right = len(sequence_right)
|
||||
@@ -384,6 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
ignore_warning=None,
|
||||
decoder_kwargs=None,
|
||||
return_timestamps=None,
|
||||
return_language=None,
|
||||
generate_kwargs=None,
|
||||
max_new_tokens=None,
|
||||
):
|
||||
@@ -413,6 +308,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if return_timestamps is not None:
|
||||
forward_params["return_timestamps"] = return_timestamps
|
||||
postprocess_params["return_timestamps"] = return_timestamps
|
||||
if return_language is not None:
|
||||
postprocess_params["return_language"] = return_language
|
||||
|
||||
return preprocess_params, forward_params, postprocess_params
|
||||
|
||||
@@ -580,7 +477,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
extra = model_inputs
|
||||
return {"is_last": is_last, **out, **extra}
|
||||
|
||||
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None):
|
||||
def postprocess(
|
||||
self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
|
||||
):
|
||||
# Optional return types
|
||||
optional = {}
|
||||
|
||||
@@ -591,12 +490,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
|
||||
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
|
||||
|
||||
if return_language is not None and self.type != "seq2seq_whisper":
|
||||
raise ValueError("Only whisper can return language for now.")
|
||||
|
||||
final_items = []
|
||||
key = "logits" if self.type == "ctc_with_lm" else "tokens"
|
||||
stride = None
|
||||
for outputs in model_outputs:
|
||||
items = outputs[key].numpy()
|
||||
stride = outputs.pop("stride", None)
|
||||
stride = outputs.get("stride", None)
|
||||
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
|
||||
total_n, left, right = stride
|
||||
# Total_n might be < logits.shape[1]
|
||||
@@ -605,15 +507,28 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
# This won't work with left padding (which doesn't exist right now)
|
||||
right_n = total_n - right
|
||||
items = items[:, left:right_n]
|
||||
if self.type == "seq2seq_whisper" and return_timestamps and stride is not None:
|
||||
# Whisper needs the stride data
|
||||
items = [items, stride]
|
||||
final_items.append(items)
|
||||
if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps:
|
||||
|
||||
if stride and self.type == "seq2seq":
|
||||
items = _find_longest_common_sequence(final_items, self.tokenizer)
|
||||
elif stride and self.type == "seq2seq_whisper" and return_timestamps:
|
||||
items = _find_timestamp_sequence(
|
||||
final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions
|
||||
elif self.type == "seq2seq_whisper":
|
||||
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
|
||||
# Send the chunking back to seconds, it's easier to handle in whisper
|
||||
sampling_rate = self.feature_extractor.sampling_rate
|
||||
for output in model_outputs:
|
||||
if "stride" in output:
|
||||
chunk_len, stride_left, stride_right = output["stride"]
|
||||
# Go back in seconds
|
||||
chunk_len /= sampling_rate
|
||||
stride_left /= sampling_rate
|
||||
stride_right /= sampling_rate
|
||||
output["stride"] = chunk_len, stride_left, stride_right
|
||||
|
||||
text, optional = self.tokenizer._decode_asr(
|
||||
model_outputs,
|
||||
return_timestamps=return_timestamps,
|
||||
return_language=return_language,
|
||||
time_precision=time_precision,
|
||||
)
|
||||
else:
|
||||
items = np.concatenate(final_items, axis=1)
|
||||
@@ -631,14 +546,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
offsets = []
|
||||
for word, (start_offset, end_offset) in chunk_offset:
|
||||
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
|
||||
else:
|
||||
elif self.type != "seq2seq_whisper":
|
||||
skip_special_tokens = self.type != "ctc"
|
||||
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
|
||||
if return_timestamps and self.type == "seq2seq_whisper":
|
||||
offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[
|
||||
"offsets"
|
||||
]
|
||||
elif return_timestamps:
|
||||
if return_timestamps:
|
||||
offsets = self.tokenizer.decode(
|
||||
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
|
||||
)["char_offsets"]
|
||||
@@ -656,14 +567,119 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
||||
|
||||
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
|
||||
optional["chunks"] = chunks
|
||||
elif return_timestamps and self.type == "seq2seq_whisper":
|
||||
optional["chunks"] = offsets
|
||||
|
||||
extra = defaultdict(list)
|
||||
for output in model_outputs:
|
||||
output.pop("tokens", None)
|
||||
output.pop("logits", None)
|
||||
output.pop("is_last", None)
|
||||
output.pop("stride", None)
|
||||
for k, v in output.items():
|
||||
extra[k].append(v)
|
||||
return {"text": text, **optional, **extra}
|
||||
|
||||
|
||||
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
|
||||
"""
|
||||
Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
|
||||
`WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
|
||||
iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
|
||||
processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
|
||||
properly compute the final `offset`.
|
||||
"""
|
||||
# index of the first timestamp token
|
||||
timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
|
||||
items = []
|
||||
# approximation of the token to time ratio : ~0.2seconds
|
||||
time_precision = feature_extractor.chunk_length / max_source_positions
|
||||
time = 0
|
||||
for seq_idx, item in enumerate(sequences):
|
||||
sequence, stride = item
|
||||
if isinstance(sequence, list):
|
||||
sequence = np.array(sequence)
|
||||
chunk_len, stride_left, stride_right = stride
|
||||
sequence = sequence.squeeze(0)
|
||||
# get rid of the `forced_decoder_idx` that are use to parametrize the generation
|
||||
begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
|
||||
sequence = sequence[begin_idx:]
|
||||
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
if seq_idx != 0 and sum(timestamp_tokens) > 0:
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
|
||||
time -= stride_left + stride_right
|
||||
offset = int((time / feature_extractor.sampling_rate) / time_precision)
|
||||
overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
|
||||
# relevant timestamps are in the overlapping part
|
||||
relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
|
||||
if relevant_timestamp.shape[0] > 0:
|
||||
relevant_timestamp = (
|
||||
consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
|
||||
)
|
||||
# if a big stride is used, we need to check some of the previous items for the best overlap
|
||||
best_match = 0
|
||||
sliced_sequence = []
|
||||
for idx, previous_sequence in enumerate(reversed(items)):
|
||||
previous_tokens = previous_sequence[1:-1]
|
||||
if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
|
||||
break # the previous sequence is too far in the past
|
||||
if len(previous_tokens) > 0:
|
||||
# find the longest common sequence between the overlapping parts
|
||||
index_left, index_right, match_length = _fast_find_longest_common_sequence(
|
||||
sequence[1:relevant_timestamp], previous_tokens
|
||||
)
|
||||
# don't do anything if only 1 token was matched
|
||||
if match_length > 1 and match_length > best_match:
|
||||
best_match = match_length
|
||||
best_idx = idx
|
||||
end_of_curr_sequence_idx = (
|
||||
np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
|
||||
)
|
||||
end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
|
||||
# if all the tokens are matched, suffix
|
||||
if index_left == 0 and match_length == len(previous_tokens):
|
||||
sliced_sequence = np.insert(
|
||||
sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
|
||||
)
|
||||
sliced_sequence[-1] = previous_sequence[-1]
|
||||
# if part of the previous sequence is not taken
|
||||
elif index_left >= 0:
|
||||
sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
|
||||
# let's insert the missing part of the previous sequence
|
||||
previous_slice = (
|
||||
previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
|
||||
)
|
||||
sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
|
||||
sliced_sequence[-1] += offset
|
||||
|
||||
if len(sliced_sequence) > 0:
|
||||
items[len(items) - best_idx - 1] = sliced_sequence
|
||||
items = items[: len(items) - best_idx]
|
||||
sequence = sequence[end_of_curr_sequence_idx:]
|
||||
|
||||
# sequence might have changed
|
||||
timestamp_tokens = sequence >= timestamp_begin
|
||||
consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
|
||||
if sum(timestamp_tokens) > 0:
|
||||
last_timestamp = np.where(timestamp_tokens)[0][-1]
|
||||
consecutive = (
|
||||
np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
|
||||
)
|
||||
|
||||
if len(consecutive) > 0:
|
||||
last_slice = 0
|
||||
for current_slice in consecutive:
|
||||
actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
|
||||
sliced_tokens = sequence[last_slice:current_slice]
|
||||
duration = sliced_tokens[-1] - sliced_tokens[0]
|
||||
sliced_tokens[0] = actual_offset
|
||||
sliced_tokens[-1] = actual_offset + duration
|
||||
items.append(sliced_tokens)
|
||||
last_slice = current_slice
|
||||
|
||||
time += chunk_len
|
||||
result = []
|
||||
for i in range(len(items)):
|
||||
result += items[i].tolist()
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user