add word-level timestamps to Whisper (#23205)

* let's go!

* initial implementation of token-level timestamps

* only return a single timestamp per token

* remove token probabilities

* fix return type

* fix doc comment

* strip special tokens

* rename

* revert to not stripping special tokens

* only support models that have alignment_heads

* add integration test

* consistently name it token-level timestamps

* small DTW tweak

* initial support for ASR pipeline

* fix pipeline doc comments

* resolve token timestamps in pipeline with chunking

* change warning when no final timestamp is found

* return word-level timestamps

* fixup

* fix bug that skipped final word in each chunk

* fix failing unit tests

* merge punctuations into the words

* also return word tokens

* also return token indices

* add (failing) unit test for combine_tokens_into_words

* make combine_tokens_into_words private

* restore OpenAI's punctuation rules

* add pipeline tests

* make requested changes

* PR review changes

* fix failing pipeline test

* small stuff from PR

* only return words and their timestamps, not segments

* move alignment_heads into generation config

* forgot to set alignment_heads in pipeline tests

* tiny comment fix

* grr
This commit is contained in:
Matthijs Hollemans
2023-06-21 17:48:21 +02:00
committed by GitHub
parent 0f968ddaa3
commit cd927a4736
8 changed files with 456 additions and 25 deletions

View File

@@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(sample["audio"]["array"], return_timestamps="word")
# fmt: off
# Note that the word-level timestamps predicted here are pretty bad.
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
{'text': ' returned', 'timestamp': (29.9, 29.9)},
{'text': ' to', 'timestamp': (29.9, 29.9)},
{'text': ' its', 'timestamp': (29.9, 29.9)},
{'text': ' place', 'timestamp': (29.9, 29.9)},
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
{'text': ' the', 'timestamp': (29.9, 29.9)},
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
]
}
)
# fmt: on
@require_torch
@slow
@@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
# fmt: off
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
)
# fmt: on
@slow
@require_torch