[Tests] Diverse Whisper fixes (#33665)
* fix beam indices in token_timestamps * fix attention_mask in FA2 * correct translation example with the right example * correct how somes tests are using outputs + correct num_frames * fix shortform batch prev cond tests * make fix-copies * make fix-copies * take care of shifting beam indices * [run-slow] whisper * [run-slow] whisper
This commit is contained in:
@@ -1916,14 +1916,14 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
EXPECTED_TRANSCRIPT = " Mein sechster Sohn scheint, wenigstens auf den ersten Blick,"
|
||||
EXPECTED_TRANSCRIPT = " Denken Sie, soeben walten meine Gedanken bei Ihnen in Adela"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
generated_ids = model.generate(
|
||||
input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
EXPECTED_TRANSCRIPT = " My sixth son seems, at least at first glance, the most deeply-minded"
|
||||
EXPECTED_TRANSCRIPT = " Think, my thoughts were just rolling with you in Adelaide, and I"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
@@ -2238,7 +2238,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||
)
|
||||
|
||||
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([
|
||||
@@ -2249,7 +2249,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||
|
||||
@slow
|
||||
def test_large_token_timestamp_generation(self):
|
||||
@@ -2268,7 +2268,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||
)
|
||||
|
||||
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_OUTPUT = torch.tensor([
|
||||
@@ -2279,7 +2279,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||
|
||||
@slow
|
||||
def test_tiny_token_timestamp_batch_generation(self):
|
||||
@@ -2306,9 +2306,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
# task id and lang id prompts should not have timestamp tokens
|
||||
self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])
|
||||
self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
|
||||
|
||||
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
|
||||
self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
|
||||
|
||||
@slow
|
||||
def test_tiny_token_timestamp_generation_longform(self):
|
||||
@@ -2799,7 +2799,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
result = model.generate(input_features, **gen_kwargs)
|
||||
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
||||
decoded = processor.batch_decode(result, skip_special_tokens=True)
|
||||
|
||||
assert decoded == EXPECTED_TEXT
|
||||
|
||||
@@ -2814,7 +2814,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
result = model.generate(input_features, **gen_kwargs)
|
||||
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
||||
decoded = processor.batch_decode(result, skip_special_tokens=True)
|
||||
|
||||
assert decoded == EXPECTED_TEXT1
|
||||
|
||||
@@ -3114,7 +3114,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
}
|
||||
|
||||
result = model.generate(**inputs, **gen_kwargs)
|
||||
decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
||||
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
|
||||
|
||||
for i in range(num_samples):
|
||||
if isinstance(EXPECTED_TEXT[i], str):
|
||||
|
||||
Reference in New Issue
Block a user