[Whisper] Fix word-level timestamps with bs>1 or num_beams>1 (#28114)
* fix frames * use smaller chunk length * correct beam search + tentative stride * fix whisper word timestamp in batch * add test batch generation with return token timestamps * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * clean a test * make style + correct typo * write clearer comments * explain test in comment --------- Co-authored-by: sanchit-gandhi <sanchit@huggingface.co> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -2224,6 +2224,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
if return_token_timestamps:
|
if return_token_timestamps:
|
||||||
kwargs["output_attentions"] = True
|
kwargs["output_attentions"] = True
|
||||||
return_dict_in_generate = True
|
return_dict_in_generate = True
|
||||||
|
kwargs["output_scores"] = True
|
||||||
|
|
||||||
if getattr(generation_config, "task", None) == "translate":
|
if getattr(generation_config, "task", None) == "translate":
|
||||||
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
logger.warning("Token-level timestamps may not be reliable for task 'translate'.")
|
||||||
@@ -2555,22 +2556,72 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
|
|||||||
# of shape (batch size, num selected, output length, input length).
|
# of shape (batch size, num selected, output length, input length).
|
||||||
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads])
|
||||||
weights = weights.permute([1, 0, 2, 3])
|
weights = weights.permute([1, 0, 2, 3])
|
||||||
if num_frames is not None:
|
|
||||||
weights = weights[..., : num_frames // 2]
|
|
||||||
|
|
||||||
# Normalize and smoothen the weights.
|
if "beam_indices" in generate_outputs:
|
||||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
# If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths
|
||||||
weights = (weights - mean) / std
|
# since the beam search strategy chooses the most probable sequences at the end of the search.
|
||||||
weights = _median_filter(weights, self.config.median_filter_width)
|
# In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length
|
||||||
|
weight_length = (generate_outputs.beam_indices != -1).sum(-1).max()
|
||||||
|
weights = weights[:, :, :weight_length]
|
||||||
|
|
||||||
# Average the different cross-attention heads.
|
# If beam index is still -1, it means that the associated token id is EOS
|
||||||
matrix = weights.mean(dim=1)
|
# We need to replace the index with 0 since index_select gives an error if any of the indexes is -1.
|
||||||
|
beam_indices = generate_outputs.beam_indices[:, :weight_length]
|
||||||
|
beam_indices = beam_indices.masked_fill(beam_indices == -1, 0)
|
||||||
|
|
||||||
|
# Select the cross attention from the right beam for each output sequences
|
||||||
|
weights = torch.stack(
|
||||||
|
[
|
||||||
|
torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i])
|
||||||
|
for i in range(beam_indices.shape[1])
|
||||||
|
],
|
||||||
|
dim=2,
|
||||||
|
)
|
||||||
|
|
||||||
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
|
timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32)
|
||||||
|
batch_size = timestamps.shape[0]
|
||||||
|
|
||||||
|
if num_frames is not None:
|
||||||
|
# two cases:
|
||||||
|
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
|
||||||
|
# 2. num_frames is different, compute the DTW matrix for each sample sequentially
|
||||||
|
|
||||||
|
# we're using np.unique because num_frames can be int/list/tuple
|
||||||
|
if len(np.unique(num_frames)) == 1:
|
||||||
|
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
|
||||||
|
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]
|
||||||
|
|
||||||
|
weights = weights[..., : num_frames // 2]
|
||||||
|
else:
|
||||||
|
# num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences
|
||||||
|
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
|
||||||
|
num_frames = np.repeat(num_frames, repeat_time)
|
||||||
|
|
||||||
|
if num_frames is None or isinstance(num_frames, int):
|
||||||
|
# Normalize and smoothen the weights.
|
||||||
|
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||||
|
weights = (weights - mean) / std
|
||||||
|
weights = _median_filter(weights, self.config.median_filter_width)
|
||||||
|
|
||||||
|
# Average the different cross-attention heads.
|
||||||
|
weights = weights.mean(dim=1)
|
||||||
|
|
||||||
# Perform dynamic time warping on each element of the batch.
|
# Perform dynamic time warping on each element of the batch.
|
||||||
for batch_idx in range(timestamps.shape[0]):
|
for batch_idx in range(batch_size):
|
||||||
text_indices, time_indices = _dynamic_time_warping(-matrix[batch_idx].double().cpu().numpy())
|
if num_frames is not None and isinstance(num_frames, (tuple, list)):
|
||||||
|
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
|
||||||
|
|
||||||
|
# Normalize and smoothen the weights.
|
||||||
|
std, mean = torch.std_mean(matrix, dim=-2, keepdim=True, unbiased=False)
|
||||||
|
matrix = (matrix - mean) / std
|
||||||
|
matrix = _median_filter(matrix, self.config.median_filter_width)
|
||||||
|
|
||||||
|
# Average the different cross-attention heads.
|
||||||
|
matrix = matrix.mean(dim=0)
|
||||||
|
else:
|
||||||
|
matrix = weights[batch_idx]
|
||||||
|
|
||||||
|
text_indices, time_indices = _dynamic_time_warping(-matrix.double().cpu().numpy())
|
||||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||||
jump_times = time_indices[jumps] * time_precision
|
jump_times = time_indices[jumps] * time_precision
|
||||||
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
|
timestamps[batch_idx, 1:] = torch.tensor(jump_times)
|
||||||
|
|||||||
@@ -559,7 +559,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
generate_kwargs["return_token_timestamps"] = True
|
generate_kwargs["return_token_timestamps"] = True
|
||||||
|
|
||||||
if stride is not None:
|
if stride is not None:
|
||||||
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
|
if isinstance(stride, tuple):
|
||||||
|
generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
|
||||||
|
else:
|
||||||
|
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
|
||||||
|
|
||||||
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
|
if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
|
||||||
generate_kwargs["input_features"] = inputs
|
generate_kwargs["input_features"] = inputs
|
||||||
|
|||||||
@@ -1850,6 +1850,35 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_tiny_token_timestamp_batch_generation(self):
|
||||||
|
set_seed(0)
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||||
|
model.to(torch_device)
|
||||||
|
model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
||||||
|
|
||||||
|
num_samples = 4
|
||||||
|
num_return_sequences = 2
|
||||||
|
|
||||||
|
input_speech = self._load_datasamples(num_samples)
|
||||||
|
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="pt").input_features.to(
|
||||||
|
torch_device
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_outputs = model.generate(
|
||||||
|
input_features,
|
||||||
|
max_length=448,
|
||||||
|
return_timestamps=True,
|
||||||
|
return_token_timestamps=True,
|
||||||
|
num_beams=3,
|
||||||
|
num_return_sequences=num_return_sequences,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)
|
||||||
|
|
||||||
|
self.assertEqual(len(generate_outputs.sequences), num_return_sequences * num_samples)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_tiny_specaugment_librispeech(self):
|
def test_tiny_specaugment_librispeech(self):
|
||||||
torch_device = "cpu"
|
torch_device = "cpu"
|
||||||
|
|||||||
@@ -674,6 +674,50 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_whisper_word_timestamps_batched(self):
|
||||||
|
pipe = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="openai/whisper-tiny",
|
||||||
|
chunk_length_s=3,
|
||||||
|
return_timestamps="word",
|
||||||
|
)
|
||||||
|
data = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
sample = data[0]["audio"]
|
||||||
|
|
||||||
|
# not the same output as test_simple_whisper_asr because of chunking
|
||||||
|
EXPECTED_OUTPUT = {
|
||||||
|
"text": " Mr. Quilder is the apostle of the middle classes and we are glad to welcome his gospel.",
|
||||||
|
"chunks": [
|
||||||
|
{"text": " Mr.", "timestamp": (0.48, 0.96)},
|
||||||
|
{"text": " Quilder", "timestamp": (0.96, 1.24)},
|
||||||
|
{"text": " is", "timestamp": (1.24, 1.5)},
|
||||||
|
{"text": " the", "timestamp": (1.5, 1.72)},
|
||||||
|
{"text": " apostle", "timestamp": (1.72, 1.98)},
|
||||||
|
{"text": " of", "timestamp": (1.98, 2.32)},
|
||||||
|
{"text": " the", "timestamp": (2.32, 2.5)},
|
||||||
|
{"text": " middle", "timestamp": (2.5, 2.68)},
|
||||||
|
{"text": " classes", "timestamp": (2.68, 3.2)},
|
||||||
|
{"text": " and", "timestamp": (3.2, 3.56)},
|
||||||
|
{"text": " we", "timestamp": (3.56, 3.68)},
|
||||||
|
{"text": " are", "timestamp": (3.68, 3.8)},
|
||||||
|
{"text": " glad", "timestamp": (3.8, 4.1)},
|
||||||
|
{"text": " to", "timestamp": (4.1, 4.34)},
|
||||||
|
{"text": " welcome", "timestamp": (4.3, 4.6)},
|
||||||
|
{"text": " his", "timestamp": (4.6, 4.94)},
|
||||||
|
{"text": " gospel.", "timestamp": (4.94, 5.82)},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# batch size 1: copy the audio sample since pipeline consumes it
|
||||||
|
output = pipe(sample.copy(), batch_size=1)
|
||||||
|
self.assertDictEqual(output, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
|
# batch size 2: input audio is chunked into smaller pieces so it's testing batching
|
||||||
|
output = pipe(sample, batch_size=2)
|
||||||
|
self.assertDictEqual(output, EXPECTED_OUTPUT)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_speech_encoder_decoder(self):
|
def test_torch_speech_encoder_decoder(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user