Whisper Timestamp processor and prediction (#20620)
* add draft logit processor * add template functions * update timestamp processor parameters * draft script * simplify code * cleanup * fixup and clean * update pipeline * style * clean up previous idea * add tokenization utils * update tokenizer and asr output * fit whisper type * style and update test * clean test * style test * update tests * update error test * update code (not based on review yet) * update tokenization * update asr pipeline * update code * cleanup and update test * fmt * remove text verification * cleanup * cleanup * add model test * update tests * update code add docstring * update code and add docstring * fix pipeline tests * add draft logit processor add template functions update timestamp processor parameters draft script simplify code cleanup fixup and clean update pipeline style clean up previous idea add tokenization utils update tokenizer and asr output fit whisper type style and update test clean test style test update tests update error test update code (not based on review yet) update tokenization update asr pipeline update code cleanup and update test fmt remove text verification cleanup cleanup add model test update tests update code add docstring update code and add docstring fix pipeline tests * Small update. * Fixup. * Tmp. * More support. * Making `forced_decoder_ids` non-mandatory for users to set.
* update and fix first bug * properly process sequence right after merge if last * todo * allow list inputs + compute begin index better * start adding tests * add the 3 edge cases * style * format sequences * fixup * update * update * style * test passes, edge cases should be good * update last value * remove Trie * update tests and expected values * handle bigger chunk_length * clean tests a bit * refactor chunk iter and clean pipeline * update tests * style * refactor chunk iter and clean pipeline * update * resolve comments * Apply suggestions from code review Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> * take stride right into account * update test expected values * Update code based on review Co-authored-by: sgugger <sylvain.gugger@gmail.com> Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com> Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This commit is contained in:
@@ -20,6 +20,8 @@ import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WhisperConfig
|
||||
from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device
|
||||
from transformers.utils import cached_property
|
||||
@@ -44,6 +46,7 @@ if is_torch_available():
|
||||
WhisperProcessor,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
|
||||
from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder
|
||||
|
||||
|
||||
@@ -1030,7 +1033,6 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_tiny_en_batched_generation(self):
|
||||
torch_device = "cuda"
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
@@ -1067,3 +1069,43 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
def test_tiny_timestamp_generation(self):
    """Check timestamp-aware generation and offset decoding with `openai/whisper-tiny`.

    Four data samples are concatenated into one waveform, generation is run with a
    `WhisperTimeStampLogitsProcessor`, and two things are verified:

    1. the first generated token ids match `EXPECTED_OUTPUT`;
    2. decoding with `output_offsets=True` yields the expected text and per-segment
       `(start, end)` timestamps.
    """
    set_seed(0)
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)

    input_speech = np.concatenate(self._load_datasamples(4))
    input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
        torch_device
    )
    model.config.forced_decoder_ids = [(1, 50259), (2, 50359), (3, 50364)]
    # The processor is built from the number of forced decoder ids
    # (presumably so it knows where free generation begins -- confirm against its docstring).
    timestamp_processor = [WhisperTimeStampLogitsProcessor(len(model.config.forced_decoder_ids))]
    generated_ids = model.generate(input_features, max_length=448, logits_processor=timestamp_processor).to("cpu")

    # fmt: off
    EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404])
    # fmt: on

    # `EXPECTED_OUTPUT` only lists the first ids, while the transcript asserted below spans
    # six segments, so the generated row is longer than the reference. Comparing the whole
    # `(1, seq_len)` tensor against it cannot broadcast; compare the leading prefix of the
    # single batch row instead, with exact equality since these are integer token ids.
    num_expected = EXPECTED_OUTPUT.shape[0]
    self.assertTrue(torch.equal(generated_ids[0, :num_expected], EXPECTED_OUTPUT))

    # fmt: off
    EXPECTED_TRANSCRIPT = [
        {
            'text': " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all,",
            'offsets': [
                {'text': ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.', 'timestamp': (0.0, 5.62)},
                {'text': " Nor is Mr. Quilter's manner less interesting than his matter.", 'timestamp': (5.62, 10.36)},
                {'text': ' He tells us that at this festive season of the year,', 'timestamp': (10.36, 14.46)},
                {'text': ' with Christmas and roast beef looming before us,', 'timestamp': (14.46, 17.76)},
                {'text': ' similes drawn from eating and its results occur most readily to the mind.', 'timestamp': (17.76, 22.8)},
                {'text': " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", 'timestamp': (22.8, 28.82)}
            ]
        }
    ]
    # fmt: on

    transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
    self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@@ -227,3 +227,71 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
|
||||
batch_encoding = multilingual_tokenizer.batch_encode_plus(batch, padding=True).input_ids
|
||||
transcription = multilingual_tokenizer.batch_decode(batch_encoding, skip_special_tokens=True)
|
||||
self.assertListEqual(batch, transcription)
|
||||
|
||||
def test_offset_decoding(self):
    """Exercise `decode(..., output_offsets=True)` on three hand-written token sequences.

    Covers a full sequence holding several timestamp-delimited segments, a single
    segment enclosed by one pair of timestamp tokens, and a sequence that has no
    opening timestamp token (expected to produce no offsets at all).
    """
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

    def decode_offsets(token_ids):
        # Only the "offsets" entry of the decoded output is under test here.
        return tokenizer.decode(token_ids, output_offsets=True)["offsets"]

    # Case 1: a full sequence containing three consecutive segments.
    # fmt: off
    multi_segment_ids = [
        50258, 50259, 50359, 50364, 441, 1857, 4174, 11, 5242, 366,
        257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
        293, 25730, 311, 454, 34152, 4496, 904, 50724, 50724, 366,
        382, 4048, 382, 257, 361, 18459, 13065, 13, 2221, 13,
        7145, 74, 325, 38756, 311, 29822, 7563, 412, 472, 709,
        294, 264, 51122, 51122, 912, 636, 300, 2221, 13, 2741,
        5767, 1143, 281, 7319, 702, 7798, 13, 400, 2221, 13,
        2619, 4004, 811, 2709, 702, 51449, 51449, 50257
    ]
    # fmt: on
    multi_segment_offsets = decode_offsets(multi_segment_ids)
    expected_segments = [
        {
            "text": " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles",
            "timestamp": (0.0, 7.2),
        },
        {
            "text": " are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the",
            "timestamp": (7.2, 15.16),
        },
        {
            "text": " same way that Mr. Carker used to flash his teeth. And Mr. John Colier gives his",
            "timestamp": (15.16, 21.7),
        },
    ]
    self.assertEqual(multi_segment_offsets, expected_segments)

    # Case 2: a single segment enclosed by a start and an end timestamp token.
    # fmt: off
    single_segment_ids = [
        50364, 441, 1857, 4174, 11, 5242, 366,
        257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
        293, 25730, 311, 454, 34152, 4496, 904, 50724
    ]
    # fmt: on
    single_segment_offsets = decode_offsets(single_segment_ids)
    expected_first_segment = {
        "text": " Lennils, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles",
        "timestamp": (0.0, 7.2),
    }
    self.assertEqual(single_segment_offsets[0], expected_first_segment)

    # Case 3: same tokens without the opening timestamp; no offsets are expected.
    # fmt: off
    no_start_timestamp_ids = [
        441, 1857, 4174, 11, 5242, 366,
        257, 1333, 295, 493, 2794, 2287, 293, 12018, 14880, 11,
        293, 25730, 311, 454, 34152, 4496, 904, 50724
    ]
    # fmt: on
    self.assertEqual(decode_offsets(no_start_timestamp_ids), [])
|
||||
|
||||
Reference in New Issue
Block a user