Refactor whisper asr pipeline to include language too. (#21427)
* [WIP] whisper refacto to support language output. * Handling merges. * A bit more cleanup and comments. * Many improvements. Lots of details everywhere. * Cleanup old code and tests. * Handle lone timestamp tokens (just recover when something bad happens). * Adding return_language example. * No ffmpeg. * Hmm. * Some corrections. * Both fast and slow. * New black. * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Remove print. * Undoing tests modifications. * Smaller test modifications. * Rename. * Remove maxDiff. --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -115,6 +116,84 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
|
||||
)
|
||||
|
||||
def test_output_offsets(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
previous_sequence = [51492, 406, 3163, 1953, 466, 13, 51612, 51612]
|
||||
self.assertEqual(
|
||||
tokenizer.decode(previous_sequence, output_offsets=True),
|
||||
{
|
||||
"text": " not worth thinking about.",
|
||||
"offsets": [{"text": " not worth thinking about.", "timestamp": (22.56, 24.96)}],
|
||||
},
|
||||
)
|
||||
|
||||
# Merge when the previous sequence is a suffix of the next sequence
|
||||
# fmt: off
|
||||
next_sequences_1 = [50364, 295, 6177, 3391, 11, 19817, 3337, 507, 307, 406, 3163, 1953, 466, 13, 50614, 50614, 2812, 9836, 14783, 390, 6263, 538, 257, 1359, 11, 8199, 6327, 1090, 322, 702, 7443, 13, 50834, 50257]
|
||||
# fmt: on
|
||||
self.assertEqual(
|
||||
tokenizer.decode(next_sequences_1, output_offsets=True),
|
||||
{
|
||||
"text": (
|
||||
" of spectators, retrievality is not worth thinking about. His instant panic was followed by a"
|
||||
" small, sharp blow high on his chest.<|endoftext|>"
|
||||
),
|
||||
"offsets": [
|
||||
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (5.0, 9.4),
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_find_longest_common_subsequence(self):
|
||||
previous_sequence = [1, 2, 3]
|
||||
next_sequence = [2, 3, 4, 5]
|
||||
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5])
|
||||
|
||||
# Now previous is larger than next.
|
||||
# We merge what we can and remove the extra right side of the left sequence
|
||||
previous_sequence = [1, 2, 3, 4, 5, 6, 7]
|
||||
next_sequence = [2, 3, 4, 5]
|
||||
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5])
|
||||
|
||||
# Nothing in common
|
||||
previous_sequence = [1, 2, 3]
|
||||
next_sequence = [4, 5, 6]
|
||||
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
# Some errors in the overlap.
|
||||
# We take from previous on the left, from the next on the right of the overlap
|
||||
previous_sequence = [1, 2, 3, 4, 99]
|
||||
next_sequence = [2, 98, 4, 5, 6]
|
||||
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5, 6])
|
||||
|
||||
# We take from previous on the left, from the next on the right of the overlap
|
||||
previous_sequence = [1, 2, 99, 4, 5]
|
||||
next_sequence = [2, 3, 4, 98, 6]
|
||||
merge = _find_longest_common_sequence([previous_sequence, next_sequence])
|
||||
self.assertEqual(merge, [1, 2, 99, 4, 98, 6])
|
||||
|
||||
# This works on 3 sequences
|
||||
seq1 = [1, 2, 3]
|
||||
seq2 = [2, 3, 4]
|
||||
seq3 = [3, 4, 5]
|
||||
merge = _find_longest_common_sequence([seq1, seq2, seq3])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5])
|
||||
|
||||
# This works on 3 sequences with errors
|
||||
seq1 = [1, 2, 3, 98, 5]
|
||||
seq2 = [2, 99, 4, 5, 6, 7]
|
||||
seq3 = [4, 97, 6, 7, 8]
|
||||
merge = _find_longest_common_sequence([seq1, seq2, seq3])
|
||||
self.assertEqual(merge, [1, 2, 3, 4, 5, 6, 7, 8])
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
@@ -538,7 +538,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
"tight-loan cloth that was the only garment he wore, the "
|
||||
"cut"
|
||||
),
|
||||
"timestamp": (5.5, 11.94),
|
||||
"timestamp": (5.5, 11.95),
|
||||
},
|
||||
{
|
||||
"text": (
|
||||
@@ -546,15 +546,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
"overstrained eyes, even the soaring arena around him "
|
||||
"with"
|
||||
),
|
||||
"timestamp": (11.94, 19.6),
|
||||
"timestamp": (11.95, 19.61),
|
||||
},
|
||||
{
|
||||
"text": " the thousands of spectators, retrievality is not worth thinking about.",
|
||||
"timestamp": (19.6, 26.66),
|
||||
"timestamp": (19.61, 25.0),
|
||||
},
|
||||
{
|
||||
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
|
||||
"timestamp": (26.66, 31.06),
|
||||
"timestamp": (25.0, 29.4),
|
||||
},
|
||||
],
|
||||
"text": (
|
||||
|
||||
Reference in New Issue
Block a user