Refactor Whisper ASR pipeline to include language too. (#21427)

* [WIP] Whisper refactor to support language output.

* Handling merges.

* A bit more cleanup and comments.

* Many improvements.

Lots of details everywhere.

* Cleanup old code and tests.

* Handle lone timestamp tokens (just recover when something bad happens).

* Adding return_language example.

* No ffmpeg.

* Hmm.

* Some corrections.

* Both fast and slow.

* New black.

* Update src/transformers/models/whisper/tokenization_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/whisper/tokenization_whisper.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Remove print.

* Undoing tests modifications.

* Smaller test modifications.

* Rename.

* Remove maxDiff.

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2023-03-02 18:12:19 +01:00
committed by GitHub
parent 8e5a1b2abb
commit 1325459105
5 changed files with 518 additions and 128 deletions

View File

@@ -15,6 +15,7 @@
import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -115,6 +116,84 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
)
def test_output_offsets(self):
    """Check that `decode(..., output_offsets=True)` returns text plus per-segment offsets.

    Two token sequences are decoded: one with a single timestamped segment and
    one with two consecutive timestamped segments. Each result must be a dict
    with the full "text" and an "offsets" list of {"text", "timestamp"} entries.
    """
    tokenizer = self.get_tokenizer()

    # Single segment: one run of text tokens between timestamp tokens.
    single_segment_ids = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
    expected_single = {
        "text": " not worth thinking about.",
        "offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
    }
    decoded_single = tokenizer.decode(single_segment_ids, output_offsets=True)
    self.assertEqual(decoded_single, expected_single)

    # Two back-to-back segments; the full decoded text ends with "<|endoftext|>"
    # while the per-segment offset texts do not include it.
    # fmt: off
    two_segment_ids = [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
    # fmt: on
    expected_full_text = (
        " of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
        " small, sharp blow high on his chest.<|endoftext|>"
    )
    expected_segments = [
        {"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
        {
            "text": " His instant panic was followed by a small, sharp blow high on his chest.",
            "timestamp": (5.0, 9.4),
        },
    ]
    decoded_two = tokenizer.decode(two_segment_ids, output_offsets=True)
    self.assertEqual(decoded_two, {"text": expected_full_text, "offsets": expected_segments})
def test_find_longest_common_subsequence(self):
    """Exercise `_find_longest_common_sequence` on a table of overlapping sequences.

    Each case is a pair `(sequences, expected_merge)`; the helper receives the
    list of sequences and its result is compared to the expected merged list.
    """
    cases = [
        # Plain overlap between the end of the first and the start of the second.
        ([[1, 2, 3], [2, 3, 4, 5]], [1, 2, 3, 4, 5]),
        # Now previous is larger than next.
        # We merge what we can and remove the extra right side of the left sequence
        ([[1, 2, 3, 4, 5, 6, 7], [2, 3, 4, 5]], [1, 2, 3, 4, 5]),
        # Nothing in common
        ([[1, 2, 3], [4, 5, 6]], [1, 2, 3, 4, 5, 6]),
        # Some errors in the overlap.
        # We take from previous on the left, from the next on the right of the overlap
        ([[1, 2, 3, 4, 99], [2, 98, 4, 5, 6]], [1, 2, 3, 4, 5, 6]),
        # We take from previous on the left, from the next on the right of the overlap
        ([[1, 2, 99, 4, 5], [2, 3, 4, 98, 6]], [1, 2, 99, 4, 98, 6]),
        # This works on 3 sequences
        ([[1, 2, 3], [2, 3, 4], [3, 4, 5]], [1, 2, 3, 4, 5]),
        # This works on 3 sequences with errors
        ([[1, 2, 3, 98, 5], [2, 99, 4, 5, 6, 7], [4, 97, 6, 7, 8]], [1, 2, 3, 4, 5, 6, 7, 8]),
    ]

    for sequences, expected_merge in cases:
        merge = _find_longest_common_sequence(sequences)
        self.assertEqual(merge, expected_merge)
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"