[Whisper] Fix word-level timestamps with bs>1 or num_beams>1 (#28114)

* fix frames

* use smaller chunk length

* correct beam search + tentative stride

* fix whisper word timestamp in batch

* add test batch generation with return token timestamps

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* clean a test

* make style + correct typo

* write clearer comments

* explain test in comment

---------

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Yoach Lacombe
2023-12-22 12:43:11 +00:00
committed by GitHub
parent c4df7c1668
commit 5da3db3fd5
4 changed files with 138 additions and 11 deletions

View File

@@ -1850,6 +1850,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
@slow
def test_tiny_token_timestamp_batch_generation(self):
    """Check token-timestamp output shapes for batched beam-search generation.

    Generates from a batch of 4 audio samples with ``num_beams=3`` and
    ``num_return_sequences=2`` while requesting token timestamps, then
    verifies that:

    * ``token_timestamps`` has exactly the same shape as ``sequences``;
    * the number of returned sequences equals
      ``num_return_sequences * num_samples``.
    """
    set_seed(0)

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    model.to(torch_device)
    # Alignment heads are set explicitly on the generation config before
    # token timestamps are requested below.
    model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]

    batch_size = 4
    sequences_per_sample = 2

    raw_audio = self._load_datasamples(batch_size)
    extracted = processor.feature_extractor(raw_speech=raw_audio, return_tensors="pt")
    features = extracted.input_features.to(torch_device)

    # Batch size > 1 together with beam search and several returned
    # sequences per sample is the configuration under test.
    generation_kwargs = {
        "max_length": 448,
        "return_timestamps": True,
        "return_token_timestamps": True,
        "num_beams": 3,
        "num_return_sequences": sequences_per_sample,
    }
    outputs = model.generate(features, **generation_kwargs)

    # One timestamp per generated token, for every returned sequence.
    self.assertEqual(outputs.sequences.shape, outputs.token_timestamps.shape)

    # Each input sample must contribute `sequences_per_sample` sequences.
    self.assertEqual(len(outputs.sequences), sequences_per_sample * batch_size)
@slow
def test_tiny_specaugment_librispeech(self):
torch_device = "cpu"