[Tests] Diverse Whisper fixes (#33665)
* fix beam indices in token_timestamps * fix attention_mask in FA2 * correct translation example with the right example * correct how somes tests are using outputs + correct num_frames * fix shortform batch prev cond tests * make fix-copies * make fix-copies * take care of shifting beam indices * [run-slow] whisper * [run-slow] whisper
This commit is contained in:
@@ -291,7 +291,7 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
|
|||||||
|
|
||||||
causal_mask = attention_mask
|
causal_mask = attention_mask
|
||||||
if attention_mask is not None: # no matter the length, we just slice it
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
causal_mask = attention_mask[:, : key_states.shape[-2]]
|
||||||
|
|
||||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
|||||||
@@ -173,7 +173,9 @@ def _pad_to_max_length(
|
|||||||
|
|
||||||
|
|
||||||
class WhisperGenerationMixin(GenerationMixin):
|
class WhisperGenerationMixin(GenerationMixin):
|
||||||
def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None):
|
def _extract_token_timestamps(
|
||||||
|
self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None, num_input_ids=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to
|
||||||
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
|
map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder
|
||||||
@@ -200,11 +202,18 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
# since the beam search strategy chooses the most probable sequences at the end of the search.
|
# since the beam search strategy chooses the most probable sequences at the end of the search.
|
||||||
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
|
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
|
||||||
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
|
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
|
||||||
|
weight_length = weight_length if num_input_ids is None else weight_length + num_input_ids
|
||||||
|
|
||||||
|
# beam search takes `decoder_input_ids` into account in the `beam_indices` length
|
||||||
|
# but forgot to shift the beam_indices by the number of `decoder_input_ids`
|
||||||
|
beam_indices = torch.zeros_like(generate_outputs.beam_indices[:, :weight_length])
|
||||||
|
# we actually shif the beam indices here
|
||||||
|
beam_indices[:, num_input_ids:] = generate_outputs.beam_indices[:, : weight_length - num_input_ids]
|
||||||
|
|
||||||
weights = weights[:, :, :weight_length]
|
weights = weights[:, :, :weight_length]
|
||||||
|
|
||||||
# If beam index is still -1, it means that the associated token id is EOS
|
# If beam index is still -1, it means that the associated token id is EOS
|
||||||
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
|
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
|
||||||
beam_indices = generate_outputs.beam_indices[:, :weight_length]
|
|
||||||
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
|
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
|
||||||
|
|
||||||
# Select the cross attention from the right beam for each output sequences
|
# Select the cross attention from the right beam for each output sequences
|
||||||
@@ -218,8 +227,10 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
|
|
||||||
# make sure timestamps are as long as weights
|
# make sure timestamps are as long as weights
|
||||||
input_length = weight_length or cross_attentions[0].shape[2]
|
input_length = weight_length or cross_attentions[0].shape[2]
|
||||||
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)[:, : input_length + 1]
|
batch_size = generate_outputs.sequences.shape[0]
|
||||||
batch_size = timestamps.shape[0]
|
timestamps = torch.zeros(
|
||||||
|
(batch_size, input_length + 1), dtype=torch.float32, device=generate_outputs.sequences.device
|
||||||
|
)
|
||||||
|
|
||||||
if num_frames is not None:
|
if num_frames is not None:
|
||||||
# two cases:
|
# two cases:
|
||||||
@@ -239,6 +250,7 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
else:
|
else:
|
||||||
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
||||||
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
||||||
|
num_frames = num_frames.cpu() if isinstance(num_frames, (torch.Tensor)) else num_frames
|
||||||
num_frames = np.repeat(num_frames, repeat_time)
|
num_frames = np.repeat(num_frames, repeat_time)
|
||||||
|
|
||||||
if num_frames is None or isinstance(num_frames, int):
|
if num_frames is None or isinstance(num_frames, int):
|
||||||
@@ -948,7 +960,10 @@ class WhisperGenerationMixin(GenerationMixin):
|
|||||||
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
|
||||||
num_frames = getattr(generation_config, "num_frames", None)
|
num_frames = getattr(generation_config, "num_frames", None)
|
||||||
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
seek_outputs["token_timestamps"] = self._extract_token_timestamps(
|
||||||
seek_outputs, generation_config.alignment_heads, num_frames=num_frames
|
seek_outputs,
|
||||||
|
generation_config.alignment_heads,
|
||||||
|
num_frames=num_frames,
|
||||||
|
num_input_ids=decoder_input_ids.shape[-1],
|
||||||
)
|
)
|
||||||
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
|
seek_outputs["token_timestamps"] = seek_outputs["token_timestamps"][:, start_idx:]
|
||||||
|
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ class WhisperFlashAttention2(WhisperAttention):
|
|||||||
|
|
||||||
causal_mask = attention_mask
|
causal_mask = attention_mask
|
||||||
if attention_mask is not None: # no matter the length, we just slice it
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
causal_mask = attention_mask[:, : key_states.shape[-2]]
|
||||||
|
|
||||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
|||||||
@@ -1916,14 +1916,14 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
|
input_features, do_sample=False, max_length=20, language="<|de|>", task="transcribe"
|
||||||
)
|
)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
EXPECTED_TRANSCRIPT = " Mein sechster Sohn scheint, wenigstens auf den ersten Blick,"
|
EXPECTED_TRANSCRIPT = " Denken Sie, soeben walten meine Gedanken bei Ihnen in Adela"
|
||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
generated_ids = model.generate(
|
generated_ids = model.generate(
|
||||||
input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
|
input_features, do_sample=False, max_length=20, language="<|de|>", task="translate"
|
||||||
)
|
)
|
||||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
EXPECTED_TRANSCRIPT = " My sixth son seems, at least at first glance, the most deeply-minded"
|
EXPECTED_TRANSCRIPT = " Think, my thoughts were just rolling with you in Adelaide, and I"
|
||||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -2238,7 +2238,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_OUTPUT = torch.tensor([
|
EXPECTED_OUTPUT = torch.tensor([
|
||||||
@@ -2249,7 +2249,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_large_token_timestamp_generation(self):
|
def test_large_token_timestamp_generation(self):
|
||||||
@@ -2268,7 +2268,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
self.assertEqual(generate_outputs["sequences"].shape, generate_outputs["token_timestamps"].shape)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_OUTPUT = torch.tensor([
|
EXPECTED_OUTPUT = torch.tensor([
|
||||||
@@ -2279,7 +2279,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
self.assertTrue(torch.allclose(generate_outputs["token_timestamps"].to("cpu"), EXPECTED_OUTPUT))
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_token_timestamp_batch_generation(self):
|
def test_tiny_token_timestamp_batch_generation(self):
|
||||||
@@ -2306,9 +2306,9 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# task id and lang id prompts should not have timestamp tokens
|
# task id and lang id prompts should not have timestamp tokens
|
||||||
self.assertEqual(generate_outputs.sequences.shape[-1] - 2, generate_outputs.token_timestamps.shape[-1])
|
self.assertEqual(generate_outputs["sequences"].shape[-1] - 2, generate_outputs["token_timestamps"].shape[-1])
|
||||||
|
|
||||||
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
|
self.assertEqual(len(generate_outputs["sequences"]), num_return_sequences * num_samples)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_token_timestamp_generation_longform(self):
|
def test_tiny_token_timestamp_generation_longform(self):
|
||||||
@@ -2799,7 +2799,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
result = model.generate(input_features, **gen_kwargs)
|
result = model.generate(input_features, **gen_kwargs)
|
||||||
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
decoded = processor.batch_decode(result, skip_special_tokens=True)
|
||||||
|
|
||||||
assert decoded == EXPECTED_TEXT
|
assert decoded == EXPECTED_TEXT
|
||||||
|
|
||||||
@@ -2814,7 +2814,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
result = model.generate(input_features, **gen_kwargs)
|
result = model.generate(input_features, **gen_kwargs)
|
||||||
decoded = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
decoded = processor.batch_decode(result, skip_special_tokens=True)
|
||||||
|
|
||||||
assert decoded == EXPECTED_TEXT1
|
assert decoded == EXPECTED_TEXT1
|
||||||
|
|
||||||
@@ -3114,7 +3114,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
result = model.generate(**inputs, **gen_kwargs)
|
result = model.generate(**inputs, **gen_kwargs)
|
||||||
decoded_all = processor.batch_decode(result.sequences, skip_special_tokens=True)
|
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
|
||||||
|
|
||||||
for i in range(num_samples):
|
for i in range(num_samples):
|
||||||
if isinstance(EXPECTED_TEXT[i], str):
|
if isinstance(EXPECTED_TEXT[i], str):
|
||||||
|
|||||||
Reference in New Issue
Block a user