[Tests] Diverse Whisper fixes (#33665)

* fix beam indices in token_timestamps

* fix attention_mask in FA2

* correct the translation test to use the right audio example

* correct how some tests use outputs + correct num_frames

* fix shortform batch prev cond tests

* make fix-copies

* make fix-copies

* take care of shifting beam indices

* [run-slow] whisper

* [run-slow] whisper
This commit is contained in:
Yoach Lacombe
2024-10-03 15:59:01 +02:00
committed by GitHub
parent ab97a78130
commit bf0ffe3d29
4 changed files with 33 additions and 18 deletions

View File

@@ -1916,14 +1916,14 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " Mein sechster Sohn scheint, wenigstens auf den ersten Blick,"
EXPECTED_TRANSCRIPT = " Denken Sie, soeben walten meine Gedanken bei Ihnen in Adela"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
generated_ids = model.generate(
input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
EXPECTED_TRANSCRIPT = " My sixth son seems, at least at first glance, the most deeply-minded"
EXPECTED_TRANSCRIPT = " Think, my thoughts were just rolling with you in Adelaide, and I"
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
@slow
@@ -2238,7 +2238,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
@@ -2249,7 +2249,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
@slow
def test_large_token_timestamp_generation(self):
@@ -2268,7 +2268,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
# fmt: off
EXPECTED_OUTPUT = torch.tensor([
@@ -2279,7 +2279,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
])
# fmt: on
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
@slow
def test_tiny_token_timestamp_batch_generation(self):
@@ -2306,9 +2306,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
)
# task id and lang id prompts should not have timestamp tokens
self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])
self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
@slow
def test_tiny_token_timestamp_generation_longform(self):
@@ -2799,7 +2799,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)
assert decoded == EXPECTED_TEXT
@@ -2814,7 +2814,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
torch.manual_seed(0)
result = model.generate(input_features, **gen_kwargs)
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)
assert decoded == EXPECTED_TEXT1
@@ -3114,7 +3114,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
}
result = model.generate(**inputs, **gen_kwargs)
decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
for i in range(num_samples):
if isinstance(EXPECTED_TEXT[i], str):