add word-level timestamps to Whisper (#23205)

* let's go!

* initial implementation of token-level timestamps

* only return a single timestamp per token

* remove token probabilities

* fix return type

* fix doc comment

* strip special tokens

* rename

* revert to not stripping special tokens

* only support models that have alignment_heads

* add integration test

* consistently name it token-level timestamps

* small DTW tweak

* initial support for ASR pipeline

* fix pipeline doc comments

* resolve token timestamps in pipeline with chunking

* change warning when no final timestamp is found

* return word-level timestamps

* fixup

* fix bug that skipped final word in each chunk

* fix failing unit tests

* merge punctuations into the words

* also return word tokens

* also return token indices

* add (failing) unit test for combine_tokens_into_words

* make combine_tokens_into_words private

* restore OpenAI's punctuation rules

* add pipeline tests

* make requested changes

* PR review changes

* fix failing pipeline test

* small stuff from PR

* only return words and their timestamps, not segments

* move alignment_heads into generation config

* forgot to set alignment_heads in pipeline tests

* tiny comment fix

* grr
This commit is contained in:
Matthijs Hollemans
2023-06-21 17:48:21 +02:00
committed by GitHub
parent 0f968ddaa3
commit cd927a4736
8 changed files with 456 additions and 25 deletions

View File

@@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
def test_tiny_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.to(torch_device)
input_speech = self._load_datasamples(4)
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
torch_device
)
generate_outputs = model.generate(
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"

View File

@@ -15,7 +15,7 @@
import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
def test_combine_tokens_into_words(self):
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
# 'whatever "whatever" said someone, clever!?'
encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]]
expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
output = _combine_tokens_into_words(tokenizer, encoded_input)
self.assertEqual(expected_words, output[0])
self.assertEqual(expected_tokens, output[1])
self.assertEqual(expected_indices, output[2])
output_rust = _combine_tokens_into_words(rust_tokenizer, encoded_input)
self.assertEqual(expected_words, output_rust[0])
self.assertEqual(expected_tokens, output_rust[1])
self.assertEqual(expected_indices, output_rust[2])
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
checkpoint_name = "openai/whisper-small.en"

View File

@@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
},
)
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
res = pipe(sample["audio"]["array"], return_timestamps="word")
# fmt: off
# Note that the word-level timestamps predicted here are pretty bad.
self.assertEqual(
res,
{
"text": " Conquered returned to its place amidst the tents.",
"chunks": [
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
{'text': ' returned', 'timestamp': (29.9, 29.9)},
{'text': ' to', 'timestamp': (29.9, 29.9)},
{'text': ' its', 'timestamp': (29.9, 29.9)},
{'text': ' place', 'timestamp': (29.9, 29.9)},
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
{'text': ' the', 'timestamp': (29.9, 29.9)},
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
]
}
)
# fmt: on
@require_torch
@slow
@@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
],
},
)
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
output = speech_recognizer(filename, return_timestamps="word")
# fmt: off
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
{'text': ' is', 'timestamp': (1.18, 1.44)},
{'text': ' the', 'timestamp': (1.44, 1.58)},
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
{'text': ' of', 'timestamp': (1.98, 2.3)},
{'text': ' the', 'timestamp': (2.3, 2.46)},
{'text': ' middle', 'timestamp': (2.46, 2.56)},
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
{'text': ' and', 'timestamp': (3.38, 3.52)},
{'text': ' we', 'timestamp': (3.52, 3.6)},
{'text': ' are', 'timestamp': (3.6, 3.72)},
{'text': ' glad', 'timestamp': (3.72, 4.0)},
{'text': ' to', 'timestamp': (4.0, 4.26)},
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
{'text': ' his', 'timestamp': (4.54, 4.92)},
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
],
},
)
# fmt: on
@slow
@require_torch