Refactor the Whisper ASR pipeline to also return the detected language. (#21427)

* [WIP] Whisper refactor to support language output.

* Handling merges.

* A bit more cleanup and comments.

* Many improvements.

Lots of details everywhere.

* Cleanup old code and tests.

* Handle lone timestamp tokens (just recover when something bad happens).

* Adding return_language example.

* No ffmpeg.

* Hmm.

* Some corrections.

* Both fast and slow.

* New black.

* Update src/transformers/models/whisper/tokenization_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/whisper/tokenization_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove print.

* Undoing tests modifications.

* Smaller test modifications.

* Rename.

* Remove maxDiff.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-03-02 18:12:19 +01:00
committed by GitHub
parent 8e5a1b2abb
commit 1325459105
5 changed files with 518 additions and 128 deletions

View File

@@ -82,112 +82,6 @@ def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right,
break
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
    """
    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
    processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
    properly compute the final `offset`.

    Args:
        sequences: iterable of `(sequence, stride)` pairs. `sequence` is a list or `np.ndarray` of token ids with a
            leading axis of size 1 (it is squeezed on axis 0). `stride` is `(chunk_len, stride_left, stride_right)`;
            these are divided by `feature_extractor.sampling_rate`, so they are presumably expressed in audio samples
            (TODO confirm against the caller).
        tokenizer: object exposing `convert_tokens_to_ids`; the id right after `"<|notimestamps|>"` is taken as the
            first timestamp token.
        feature_extractor: object exposing `chunk_length` and `sampling_rate`.
        max_source_positions: divisor used to turn `feature_extractor.chunk_length` into a per-timestamp-step duration.

    Returns:
        `list`: the flat list of token ids of all merged segments, with segment boundary timestamps rebased so that
        each segment starts where the previous one ended.

    The input arrays are left untouched: segments are copied before their boundary timestamps are rewritten.
    """
    # index of the first timestamp token
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    items = []
    # duration covered by one timestamp step (same unit as `feature_extractor.chunk_length`)
    time_precision = feature_extractor.chunk_length / max_source_positions
    time = 0
    for seq_idx, item in enumerate(sequences):
        sequence, stride = item
        if isinstance(sequence, list):
            sequence = np.array(sequence)
        chunk_len, stride_left, stride_right = stride
        sequence = sequence.squeeze(0)
        # get rid of the `forced_decoder_idx` that are used to parametrize the generation
        begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
        sequence = sequence[begin_idx:]

        timestamp_tokens = sequence >= timestamp_begin
        if seq_idx != 0 and timestamp_tokens.any():
            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
            time -= stride_left + stride_right
            offset = int((time / feature_extractor.sampling_rate) / time_precision)
            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
            # relevant timestamps are in the overlapping part
            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
            if relevant_timestamp.shape[0] > 0:
                relevant_timestamp = (
                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
                )
                # if a big stride is used, we need to check some of the previous items for the best overlap
                best_match = 0
                sliced_sequence = []
                for idx, previous_sequence in enumerate(reversed(items)):
                    # strip the leading/trailing timestamp tokens of the stored segment
                    previous_tokens = previous_sequence[1:-1]
                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
                        break  # the previous sequence is too far in the past
                    if len(previous_tokens) > 0:
                        # find the longest common sequence between the overlapping parts
                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
                            sequence[1:relevant_timestamp], previous_tokens
                        )
                        # don't do anything if only 1 token was matched
                        if match_length > 1 and match_length > best_match:
                            best_match = match_length
                            best_idx = idx
                            end_of_curr_sequence_idx = (
                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                            )
                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                            # if all the tokens are matched, suffix
                            if index_left == 0 and match_length == len(previous_tokens):
                                sliced_sequence = np.insert(
                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
                                )
                                sliced_sequence[-1] = previous_sequence[-1]
                            # if part of the previous sequence is not taken
                            elif index_left >= 0:
                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                # let's insert the missing part of the previous sequence
                                previous_slice = (
                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
                                )
                                # `np.insert` returns a new array, so the in-place offset below is safe
                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                sliced_sequence[-1] += offset

                # a non-empty `sliced_sequence` implies `best_idx` / `end_of_curr_sequence_idx` were set above
                if len(sliced_sequence) > 0:
                    items[len(items) - best_idx - 1] = sliced_sequence
                    items = items[: len(items) - best_idx]
                    sequence = sequence[end_of_curr_sequence_idx:]

        # sequence might have changed
        timestamp_tokens = sequence >= timestamp_begin
        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
        if timestamp_tokens.any():
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = (
                np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
            )

        if len(consecutive) > 0:
            last_slice = 0
            for current_slice in consecutive:
                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                # copy: a plain slice is a view, and rewriting it would silently modify the caller's array
                sliced_tokens = sequence[last_slice:current_slice].copy()
                duration = sliced_tokens[-1] - sliced_tokens[0]
                sliced_tokens[0] = actual_offset
                sliced_tokens[-1] = actual_offset + duration
                items.append(sliced_tokens)
                last_slice = current_slice

        time += chunk_len
    result = []
    for segment in items:
        result.extend(segment.tolist())
    return result
def _fast_find_longest_common_sequence(sequence_left, sequence_right):
seq_len_left = len(sequence_left)
seq_len_right = len(sequence_right)
@@ -384,6 +278,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
ignore_warning=None,
decoder_kwargs=None,
return_timestamps=None,
return_language=None,
generate_kwargs=None,
max_new_tokens=None,
):
@@ -413,6 +308,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps is not None:
forward_params["return_timestamps"] = return_timestamps
postprocess_params["return_timestamps"] = return_timestamps
if return_language is not None:
postprocess_params["return_language"] = return_language
return preprocess_params, forward_params, postprocess_params
@@ -580,7 +477,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
extra = model_inputs
return {"is_last": is_last, **out, **extra}
def postprocess(self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None):
def postprocess(
self, model_outputs, decoder_kwargs: Optional[Dict] = None, return_timestamps=None, return_language=None
):
# Optional return types
optional = {}
@@ -591,12 +490,15 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
if return_timestamps in {"char", "words"} and self.type == "seq2seq_whisper":
raise ValueError("Whisper cannot return `char` nor `words` timestamps, use `True` instead.")
if return_language is not None and self.type != "seq2seq_whisper":
raise ValueError("Only whisper can return language for now.")
final_items = []
key = "logits" if self.type == "ctc_with_lm" else "tokens"
stride = None
for outputs in model_outputs:
items = outputs[key].numpy()
stride = outputs.pop("stride", None)
stride = outputs.get("stride", None)
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
total_n, left, right = stride
# Total_n might be < logits.shape[1]
@@ -605,15 +507,28 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# This won't work with left padding (which doesn't exist right now)
right_n = total_n - right
items = items[:, left:right_n]
if self.type == "seq2seq_whisper" and return_timestamps and stride is not None:
# Whisper needs the stride data
items = [items, stride]
final_items.append(items)
if stride and self.type in {"seq2seq", "seq2seq_whisper"} and not return_timestamps:
if stride and self.type == "seq2seq":
items = _find_longest_common_sequence(final_items, self.tokenizer)
elif stride and self.type == "seq2seq_whisper" and return_timestamps:
items = _find_timestamp_sequence(
final_items, self.tokenizer, self.feature_extractor, self.model.config.max_source_positions
elif self.type == "seq2seq_whisper":
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
# Send the chunking back to seconds, it's easier to handle in whisper
sampling_rate = self.feature_extractor.sampling_rate
for output in model_outputs:
if "stride" in output:
chunk_len, stride_left, stride_right = output["stride"]
# Go back in seconds
chunk_len /= sampling_rate
stride_left /= sampling_rate
stride_right /= sampling_rate
output["stride"] = chunk_len, stride_left, stride_right
text, optional = self.tokenizer._decode_asr(
model_outputs,
return_timestamps=return_timestamps,
return_language=return_language,
time_precision=time_precision,
)
else:
items = np.concatenate(final_items, axis=1)
@@ -631,14 +546,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
offsets = []
for word, (start_offset, end_offset) in chunk_offset:
offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
else:
elif self.type != "seq2seq_whisper":
skip_special_tokens = self.type != "ctc"
text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
if return_timestamps and self.type == "seq2seq_whisper":
offsets = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens, output_offsets=True)[
"offsets"
]
elif return_timestamps:
if return_timestamps:
offsets = self.tokenizer.decode(
items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
)["char_offsets"]
@@ -656,14 +567,119 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
optional["chunks"] = chunks
elif return_timestamps and self.type == "seq2seq_whisper":
optional["chunks"] = offsets
extra = defaultdict(list)
for output in model_outputs:
output.pop("tokens", None)
output.pop("logits", None)
output.pop("is_last", None)
output.pop("stride", None)
for k, v in output.items():
extra[k].append(v)
return {"text": text, **optional, **extra}
def _find_timestamp_sequence(sequences, tokenizer, feature_extractor, max_source_positions):
    """
    Computes the final sequences by merging the end of the nth sequence with the beginning of the n+1th sequence. Since
    `WhisperForConditionalGeneration` produces the timestamps pairwise, we filter the consecutive timestamps and only
    iterate over them. We keep track of the `time` which indicates the actual starting time of the chunk that is
    processed. We need to make sure to offset the timestamps tokens by the `time` in order for the tokenizer to
    properly compute the final `offset`.

    Args:
        sequences: iterable of `(sequence, stride)` pairs. `sequence` is a list or `np.ndarray` of token ids with a
            leading axis of size 1 (it is squeezed on axis 0). `stride` is `(chunk_len, stride_left, stride_right)`;
            these are divided by `feature_extractor.sampling_rate`, so they are presumably expressed in audio samples
            (TODO confirm against the caller).
        tokenizer: object exposing `convert_tokens_to_ids`; the id right after `"<|notimestamps|>"` is taken as the
            first timestamp token.
        feature_extractor: object exposing `chunk_length` and `sampling_rate`.
        max_source_positions: divisor used to turn `feature_extractor.chunk_length` into a per-timestamp-step duration.

    Returns:
        `list`: the flat list of token ids of all merged segments, with segment boundary timestamps rebased so that
        each segment starts where the previous one ended.

    The input arrays are left untouched: segments are copied before their boundary timestamps are rewritten.
    """
    # index of the first timestamp token
    timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
    items = []
    # duration covered by one timestamp step (same unit as `feature_extractor.chunk_length`)
    time_precision = feature_extractor.chunk_length / max_source_positions
    time = 0
    for seq_idx, item in enumerate(sequences):
        sequence, stride = item
        if isinstance(sequence, list):
            sequence = np.array(sequence)
        chunk_len, stride_left, stride_right = stride
        sequence = sequence.squeeze(0)
        # get rid of the `forced_decoder_idx` that are used to parametrize the generation
        begin_idx = np.where(sequence == timestamp_begin)[0][0] if timestamp_begin in sequence else 0
        sequence = sequence[begin_idx:]

        timestamp_tokens = sequence >= timestamp_begin
        if seq_idx != 0 and timestamp_tokens.any():
            consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = np.append(consecutive, last_timestamp) if last_timestamp not in consecutive else consecutive
            time -= stride_left + stride_right
            offset = int((time / feature_extractor.sampling_rate) / time_precision)
            overlap_time = int((stride_left / feature_extractor.sampling_rate) / time_precision)
            # relevant timestamps are in the overlapping part
            relevant_timestamp = np.where(sequence[consecutive] >= timestamp_begin + overlap_time)[0]
            if relevant_timestamp.shape[0] > 0:
                relevant_timestamp = (
                    consecutive[relevant_timestamp[0] - 1] if relevant_timestamp[0] > 0 else consecutive[0]
                )
                # if a big stride is used, we need to check some of the previous items for the best overlap
                best_match = 0
                sliced_sequence = []
                for idx, previous_sequence in enumerate(reversed(items)):
                    # strip the leading/trailing timestamp tokens of the stored segment
                    previous_tokens = previous_sequence[1:-1]
                    if previous_sequence[0] < (timestamp_begin + offset - overlap_time) and idx != 0:
                        break  # the previous sequence is too far in the past
                    if len(previous_tokens) > 0:
                        # find the longest common sequence between the overlapping parts
                        index_left, index_right, match_length = _fast_find_longest_common_sequence(
                            sequence[1:relevant_timestamp], previous_tokens
                        )
                        # don't do anything if only 1 token was matched
                        if match_length > 1 and match_length > best_match:
                            best_match = match_length
                            best_idx = idx
                            end_of_curr_sequence_idx = (
                                np.where(sequence[index_left + 1 :] >= timestamp_begin)[0][0] + 1
                            )
                            end_of_curr_sequence_idx = end_of_curr_sequence_idx + 1 + index_left
                            # if all the tokens are matched, suffix
                            if index_left == 0 and match_length == len(previous_tokens):
                                sliced_sequence = np.insert(
                                    sequence[index_left + 1 : end_of_curr_sequence_idx], 0, previous_sequence[0]
                                )
                                sliced_sequence[-1] = previous_sequence[-1]
                            # if part of the previous sequence is not taken
                            elif index_left >= 0:
                                sliced_sequence = sequence[index_left + 1 : end_of_curr_sequence_idx]
                                # let's insert the missing part of the previous sequence
                                previous_slice = (
                                    previous_sequence[: index_right + 1] if index_right > 0 else [previous_sequence[0]]
                                )
                                # `np.insert` returns a new array, so the in-place offset below is safe
                                sliced_sequence = np.insert(sliced_sequence, 0, previous_slice)
                                sliced_sequence[-1] += offset

                # a non-empty `sliced_sequence` implies `best_idx` / `end_of_curr_sequence_idx` were set above
                if len(sliced_sequence) > 0:
                    items[len(items) - best_idx - 1] = sliced_sequence
                    items = items[: len(items) - best_idx]
                    sequence = sequence[end_of_curr_sequence_idx:]

        # sequence might have changed
        timestamp_tokens = sequence >= timestamp_begin
        consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
        if timestamp_tokens.any():
            last_timestamp = np.where(timestamp_tokens)[0][-1]
            consecutive = (
                np.append(consecutive, last_timestamp + 1) if last_timestamp not in consecutive else consecutive
            )

        if len(consecutive) > 0:
            last_slice = 0
            for current_slice in consecutive:
                actual_offset = items[-1][-1] if seq_idx != 0 or last_slice != 0 else sequence[0]
                # copy: a plain slice is a view, and rewriting it would silently modify the caller's array
                sliced_tokens = sequence[last_slice:current_slice].copy()
                duration = sliced_tokens[-1] - sliced_tokens[0]
                sliced_tokens[0] = actual_offset
                sliced_tokens[-1] = actual_offset + duration
                items.append(sliced_tokens)
                last_slice = current_slice

        time += chunk_len
    result = []
    for segment in items:
        result.extend(segment.tolist())
    return result