[Whisper] Fix word-level timestamps with bs>1 or num_beams>1 (#28114)

* fix frames

* use smaller chunk length

* correct beam search + tentative stride

* fix whisper word timestamp in batch

* add test batch generation with return token timestamps

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* clean a test

* make style + correct typo

* write clearer comments

* explain test in comment

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2023-12-22 12:43:11 +00:00
committed by GitHub
parent c4df7c1668
commit 5da3db3fd5
4 changed files with 138 additions and 11 deletions

View File

@@ -674,6 +674,50 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
},
)
@slow
@require_torch
def test_whisper_word_timestamps_batched(self):
pipe = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny",
chunk_length_s=3,
return_timestamps="word",
)
data = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = data[0]["audio"]
# not the same output as test_simple_whisper_asr because of chunking
EXPECTED_OUTPUT = {
"text": " Mr. Quilder is the apostle of the middle classes and we are glad to welcome his gospel.",
"chunks": [
{"text": " Mr.", "timestamp": (0.48, 0.96)},
{"text": " Quilder", "timestamp": (0.96, 1.24)},
{"text": " is", "timestamp": (1.24, 1.5)},
{"text": " the", "timestamp": (1.5, 1.72)},
{"text": " apostle", "timestamp": (1.72, 1.98)},
{"text": " of", "timestamp": (1.98, 2.32)},
{"text": " the", "timestamp": (2.32, 2.5)},
{"text": " middle", "timestamp": (2.5, 2.68)},
{"text": " classes", "timestamp": (2.68, 3.2)},
{"text": " and", "timestamp": (3.2, 3.56)},
{"text": " we", "timestamp": (3.56, 3.68)},
{"text": " are", "timestamp": (3.68, 3.8)},
{"text": " glad", "timestamp": (3.8, 4.1)},
{"text": " to", "timestamp": (4.1, 4.34)},
{"text": " welcome", "timestamp": (4.3, 4.6)},
{"text": " his", "timestamp": (4.6, 4.94)},
{"text": " gospel.", "timestamp": (4.94, 5.82)},
],
}
# batch size 1: copy the audio sample since pipeline consumes it
output = pipe(sample.copy(), batch_size=1)
self.assertDictEqual(output, EXPECTED_OUTPUT)
# batch size 2: input audio is chunked into smaller pieces so it's testing batching
output = pipe(sample, batch_size=2)
self.assertDictEqual(output, EXPECTED_OUTPUT)
@require_torch
@slow
def test_torch_speech_encoder_decoder(self):