[Whisper] Fix word-level timestamps with bs>1 or num_beams>1 (#28114)
* fix frames * use smaller chunk length * correct beam search + tentative stride * fix whisper word timestamp in batch * add test batch generation with return token timestamps * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * clean a test * make style + correct typo * write clearer comments * explain test in comment --------- Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -1850,6 +1850,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||
|
||||
@slow
|
||||
def test_tiny_token_timestamp_batch_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
model.to(torch_device)
|
||||
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||
|
||||
num_samples = 4
|
||||
num_return_sequences = 2
|
||||
|
||||
input_speech = self._load_datasamples(num_samples)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
generate_outputs = model.generate(
|
||||
input_features,
|
||||
max_length=448,
|
||||
return_timestamps=True,
|
||||
return_token_timestamps=True,
|
||||
num_beams=3,
|
||||
num_return_sequences=num_return_sequences,
|
||||
)
|
||||
|
||||
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||
|
||||
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
|
||||
|
||||
@slow
|
||||
def test_tiny_specaugment_librispeech(self):
|
||||
torch_device = "cpu"
|
||||
|
||||
@@ -674,6 +674,50 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_whisper_word_timestamps_batched(self):
|
||||
pipe = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-tiny",
|
||||
chunk_length_s=3,
|
||||
return_timestamps="word",
|
||||
)
|
||||
data = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
sample = data[0]["audio"]
|
||||
|
||||
# not the same output as test_simple_whisper_asr because of chunking
|
||||
EXPECTED_OUTPUT = {
|
||||
"text": " Mr. Quilder is the apostle of the middle classes and we are glad to welcome his gospel.",
|
||||
"chunks": [
|
||||
{"text": " Mr.", "timestamp": (0.48, 0.96)},
|
||||
{"text": " Quilder", "timestamp": (0.96, 1.24)},
|
||||
{"text": " is", "timestamp": (1.24, 1.5)},
|
||||
{"text": " the", "timestamp": (1.5, 1.72)},
|
||||
{"text": " apostle", "timestamp": (1.72, 1.98)},
|
||||
{"text": " of", "timestamp": (1.98, 2.32)},
|
||||
{"text": " the", "timestamp": (2.32, 2.5)},
|
||||
{"text": " middle", "timestamp": (2.5, 2.68)},
|
||||
{"text": " classes", "timestamp": (2.68, 3.2)},
|
||||
{"text": " and", "timestamp": (3.2, 3.56)},
|
||||
{"text": " we", "timestamp": (3.56, 3.68)},
|
||||
{"text": " are", "timestamp": (3.68, 3.8)},
|
||||
{"text": " glad", "timestamp": (3.8, 4.1)},
|
||||
{"text": " to", "timestamp": (4.1, 4.34)},
|
||||
{"text": " welcome", "timestamp": (4.3, 4.6)},
|
||||
{"text": " his", "timestamp": (4.6, 4.94)},
|
||||
{"text": " gospel.", "timestamp": (4.94, 5.82)},
|
||||
],
|
||||
}
|
||||
|
||||
# batch size 1: copy the audio sample since pipeline consumes it
|
||||
output = pipe(sample.copy(), batch_size=1)
|
||||
self.assertDictEqual(output, EXPECTED_OUTPUT)
|
||||
|
||||
# batch size 2: input audio is chunked into smaller pieces so it's testing batching
|
||||
output = pipe(sample, batch_size=2)
|
||||
self.assertDictEqual(output, EXPECTED_OUTPUT)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_speech_encoder_decoder(self):
|
||||
|
||||
Reference in New Issue
Block a user