add word-level timestamps to Whisper (#23205)
* let's go!
* initial implementation of token-level timestamps
* only return a single timestamp per token
* remove token probabilities
* fix return type
* fix doc comment
* strip special tokens
* rename
* revert to not stripping special tokens
* only support models that have alignment_heads
* add integration test
* consistently name it token-level timestamps
* small DTW tweak
* initial support for ASR pipeline
* fix pipeline doc comments
* resolve token timestamps in pipeline with chunking
* change warning when no final timestamp is found
* return word-level timestamps
* fixup
* fix bug that skipped final word in each chunk
* fix failing unit tests
* merge punctuations into the words
* also return word tokens
* also return token indices
* add (failing) unit test for combine_tokens_into_words
* make combine_tokens_into_words private
* restore OpenAI's punctuation rules
* add pipeline tests
* make requested changes
* PR review changes
* fix failing pipeline test
* small stuff from PR
* only return words and their timestamps, not segments
* move alignment_heads into generation config
* forgot to set alignment_heads in pipeline tests
* tiny comment fix
* grr
This commit is contained in:
committed by
GitHub
parent
0f968ddaa3
commit
cd927a4736
@@ -15,7 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
|
||||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
||||
|
||||
def test_combine_tokens_into_words(self):
    """Check the three outputs of `_combine_tokens_into_words` for one fixed input.

    The same list of token ids is passed together with the tokenizer from
    `get_tokenizer()` and then with the one from `get_rust_tokenizer()`.
    For each, elements 0, 1 and 2 of the result are compared against the
    expected words, the expected token ids grouped per word, and the
    expected token indices grouped per word.
    """
    py_tokenizer = self.get_tokenizer()
    rs_tokenizer = self.get_rust_tokenizer()

    # 'whatever "whatever" said someone, clever!?'
    encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]

    # Ordered to line up with positions 0, 1 and 2 of the returned value.
    expected_by_position = (
        # words
        ["whatever", ' "whatever"', " said", " someone,", " clever!?"],
        # token ids grouped per word
        [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]],
        # indices into `encoded_input` grouped per word
        [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]],
    )

    for candidate in (py_tokenizer, rs_tokenizer):
        result = _combine_tokens_into_words(candidate, encoded_input)
        for position, wanted in enumerate(expected_by_position):
            self.assertEqual(wanted, result[position])
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
Reference in New Issue
Block a user