add word-level timestamps to Whisper (#23205)
* let's go!
* initial implementation of token-level timestamps
* only return a single timestamp per token
* remove token probabilities
* fix return type
* fix doc comment
* strip special tokens
* rename
* revert to not stripping special tokens
* only support models that have alignment_heads
* add integration test
* consistently name it token-level timestamps
* small DTW tweak
* initial support for ASR pipeline
* fix pipeline doc comments
* resolve token timestamps in pipeline with chunking
* change warning when no final timestamp is found
* return word-level timestamps
* fixup
* fix bug that skipped final word in each chunk
* fix failing unit tests
* merge punctuations into the words
* also return word tokens
* also return token indices
* add (failing) unit test for combine_tokens_into_words
* make combine_tokens_into_words private
* restore OpenAI's punctuation rules
* add pipeline tests
* make requested changes
* PR review changes
* fix failing pipeline test
* small stuff from PR
* only return words and their timestamps, not segments
* move alignment_heads into generation config
* forgot to set alignment_heads in pipeline tests
* tiny comment fix
* grr
This commit is contained in:
committed by
GitHub
parent
0f968ddaa3
commit
cd927a4736
@@ -1436,6 +1436,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
def test_tiny_token_timestamp_generation(self):
    """Check token-level timestamps produced by `generate` on whisper-tiny.

    Runs generation with both `return_timestamps=True` and
    `return_token_timestamps=True`, then verifies that the returned
    `token_timestamps` tensor has the same shape as `sequences` and is
    close to a four-row reference tensor.
    """
    set_seed(0)
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)

    speech_samples = self._load_datasamples(4)
    extracted = processor.feature_extractor(raw_speech=speech_samples, return_tensors="pt")
    features = extracted.input_features.to(torch_device)

    outputs = model.generate(
        features,
        max_length=448,
        return_timestamps=True,
        return_token_timestamps=True,
    )

    # One timestamp is expected for every generated token.
    self.assertEqual(outputs.sequences.shape, outputs.token_timestamps.shape)

    # fmt: off
    expected_timestamps = torch.tensor([
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4800, 0.8200, 0.9600, 1.1200, 1.1200, 1.2200, 1.5000, 1.7200, 2.0000, 2.3400, 2.5000, 2.6600, 3.1800, 3.5600, 3.6800, 3.8000, 4.1000, 4.3000, 4.5800, 4.9400, 5.3800, 12.4200, 12.8400, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9200, 26.9400, 26.9400, 26.9400, 26.9400, 29.8400 ],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.5200, 0.9000, 1.1400, 1.4200, 1.5200, 1.6800, 1.6800, 1.8800, 2.1000, 2.2200, 2.6200, 3.1400, 3.5800, 3.9600, 4.4000, 17.3000, 17.3000, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7200, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 26.7400, 28.0000 ],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7600, 1.0000, 1.4200, 1.8000, 1.9400, 2.1800, 2.5200, 3.0200, 3.3200, 3.5400, 3.9400, 4.5600, 4.9200, 5.2800, 5.5600, 5.9000, 6.1600, 6.3000, 6.4800, 6.4800, 6.6400, 7.8200, 7.9600, 8.2200, 8.6000, 8.9200, 9.2200, 9.5200, 9.7200, 10.0600, 10.5400, 10.8800, 11.2600, 11.5400, 11.7400, 12.0800, 15.6800, 15.6800],
        [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7400, 1.0400, 1.3200, 1.6800, 2.1400, 2.4800, 2.7800, 3.0800, 3.1600, 3.4000, 3.6000, 4.0200, 4.2200, 4.8600, 5.2400, 5.7400, 6.3400, 6.6200, 6.7600, 6.7600, 6.8600, 7.2400, 7.4200, 7.6800, 7.9200, 8.4800, 8.7600, 9.2000, 9.2000, 9.4200, 15.8200, 15.8200, 29.6400, 29.6600, 29.6600, 29.6600, 29.6600, 29.7600]
    ])
    # fmt: on

    # Compare on CPU so the reference tensor never has to be moved to the test device.
    predicted_timestamps = outputs.token_timestamps.to("cpu")
    self.assertTrue(torch.allclose(predicted_timestamps, expected_timestamps))
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
torch_device = "cpu"
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
|
||||
from transformers.models.whisper.tokenization_whisper import _find_longest_common_sequence
|
||||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from ...test_tokenization_common import TokenizerTesterMixin
|
||||
@@ -255,6 +255,24 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertListEqual(tokenizer_prompt_ids.tolist(), fast_tokenizer_prompt_ids.tolist())
|
||||
|
||||
def test_combine_tokens_into_words(self):
    """Check `_combine_tokens_into_words` with both tokenizer implementations.

    For a fixed list of token ids, the helper's result is expected to hold
    the words at index 0, the per-word token ids at index 1 and the
    per-word token indices at index 2, for the tokenizer returned by
    `get_tokenizer` as well as the one returned by `get_rust_tokenizer`.
    """
    slow_tokenizer = self.get_tokenizer()
    fast_tokenizer = self.get_rust_tokenizer()

    # 'whatever "whatever" said someone, clever!?'
    token_ids = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]

    # Expected values, ordered like the helper's result: words, tokens, indices.
    expected_parts = (
        ["whatever", ' "whatever"', " said", " someone,", " clever!?"],
        [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]],
        [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]],
    )

    for candidate in (slow_tokenizer, fast_tokenizer):
        combined = _combine_tokens_into_words(candidate, token_ids)
        for position, wanted in enumerate(expected_parts):
            self.assertEqual(wanted, combined[position])
|
||||
|
||||
|
||||
class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
checkpoint_name = "openai/whisper-small.en"
|
||||
|
||||
@@ -316,6 +316,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
|
||||
},
|
||||
)
|
||||
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
||||
# fmt: off
|
||||
# Note that the word-level timestamps predicted here are pretty bad.
|
||||
self.assertEqual(
|
||||
res,
|
||||
{
|
||||
"text": " Conquered returned to its place amidst the tents.",
|
||||
"chunks": [
|
||||
{'text': ' Conquered', 'timestamp': (29.78, 29.9)},
|
||||
{'text': ' returned', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' to', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' its', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' place', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' amidst', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' the', 'timestamp': (29.9, 29.9)},
|
||||
{'text': ' tents.', 'timestamp': (29.9, 29.9)}
|
||||
]
|
||||
}
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -699,6 +720,35 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
],
|
||||
},
|
||||
)
|
||||
speech_recognizer.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
output = speech_recognizer(filename, return_timestamps="word")
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||
"chunks": [
|
||||
{'text': ' Mr.', 'timestamp': (0.0, 1.02)},
|
||||
{'text': ' Quilter', 'timestamp': (1.02, 1.18)},
|
||||
{'text': ' is', 'timestamp': (1.18, 1.44)},
|
||||
{'text': ' the', 'timestamp': (1.44, 1.58)},
|
||||
{'text': ' apostle', 'timestamp': (1.58, 1.98)},
|
||||
{'text': ' of', 'timestamp': (1.98, 2.3)},
|
||||
{'text': ' the', 'timestamp': (2.3, 2.46)},
|
||||
{'text': ' middle', 'timestamp': (2.46, 2.56)},
|
||||
{'text': ' classes,', 'timestamp': (2.56, 3.38)},
|
||||
{'text': ' and', 'timestamp': (3.38, 3.52)},
|
||||
{'text': ' we', 'timestamp': (3.52, 3.6)},
|
||||
{'text': ' are', 'timestamp': (3.6, 3.72)},
|
||||
{'text': ' glad', 'timestamp': (3.72, 4.0)},
|
||||
{'text': ' to', 'timestamp': (4.0, 4.26)},
|
||||
{'text': ' welcome', 'timestamp': (4.26, 4.54)},
|
||||
{'text': ' his', 'timestamp': (4.54, 4.92)},
|
||||
{'text': ' gospel.', 'timestamp': (4.92, 6.66)},
|
||||
],
|
||||
},
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user