Incorrect Whisper long-form decoding timestamps (#32003)
* fix long-form timestamps in batch_decode * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * add test * make style * fix copies * Update src/transformers/models/whisper/tokenization_whisper_fast.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/whisper/processing_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/whisper/tokenization_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * apply review suggestions * fix * fix copies * fix * Update src/transformers/models/whisper/tokenization_whisper_fast.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * fix-copies --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -2001,6 +2001,72 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
def test_tiny_longform_timestamps_generation(self):
    # Integration check: generate with timestamps and segments on one data
    # sample repeated ten times, then compare the decoded offsets against a
    # fixed expected transcript.
    set_seed(0)
    checkpoint = "openai/whisper-tiny"
    processor = WhisperProcessor.from_pretrained(checkpoint)
    model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
    model.to(torch_device)

    # NOTE(review): `_load_datasamples(1)` presumably returns a one-element list,
    # so `* repeats` tiles the audio rather than scaling it -- confirm.
    repeats = 10
    long_audio = np.concatenate(self._load_datasamples(1) * repeats)

    inputs = processor(
        long_audio,
        return_tensors="pt",
        truncation=False,
        sampling_rate=16_000,
    ).to(torch_device)

    generation_output = model.generate(**inputs, return_timestamps=True, return_segments=True)

    # Every expected segment carries the same sentence; each segment starts
    # where the previous one ends, so consecutive boundaries give the
    # (start, end) timestamp pairs.
    sentence = " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
    boundaries = [0.0, 6.0, 12.0, 18.0, 24.0, 29.0, 35.0, 41.0, 47.0, 53.0, 58.20000076293945]
    EXPECTED_TRANSCRIPT = [
        {
            "text": sentence * repeats,
            "offsets": [
                {"text": sentence, "timestamp": (start, end)}
                for start, end in zip(boundaries, boundaries[1:])
            ],
        }
    ]

    transcript = processor.batch_decode(generation_output, skip_special_tokens=True, output_offsets=True)
    self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_large_timestamp_generation(self):
|
||||
set_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user