Add language to word timestamps for Whisper (#31572)
* add language to words _collate_word_timestamps uses the return_language flag to determine whether the language of the chunk should be added to the word's information * ran style checks added missing comma * add new language test test that the pipeline can return both the language and timestamp * remove model configuration in test Removed model configurations that do not influence test results * remove model configuration in test Removed model configurations that do not influence test results
This commit is contained in:
@@ -1033,7 +1033,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
chunk["text"] = resolved_text
|
chunk["text"] = resolved_text
|
||||||
if return_timestamps == "word":
|
if return_timestamps == "word":
|
||||||
chunk["words"] = _collate_word_timestamps(
|
chunk["words"] = _collate_word_timestamps(
|
||||||
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
|
tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
|
||||||
)
|
)
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
@@ -1085,7 +1085,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language,
|
|||||||
chunk["text"] = resolved_text
|
chunk["text"] = resolved_text
|
||||||
if return_timestamps == "word":
|
if return_timestamps == "word":
|
||||||
chunk["words"] = _collate_word_timestamps(
|
chunk["words"] = _collate_word_timestamps(
|
||||||
tokenizer, resolved_tokens, resolved_token_timestamps, last_language
|
tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
|
||||||
)
|
)
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
@@ -1217,12 +1217,16 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
|
|||||||
return total_sequence, []
|
return total_sequence, []
|
||||||
|
|
||||||
|
|
||||||
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language):
|
def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language):
|
||||||
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
|
words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
|
||||||
|
|
||||||
|
optional_language_field = {"language": language} if return_language else {}
|
||||||
|
|
||||||
timings = [
|
timings = [
|
||||||
{
|
{
|
||||||
"text": word,
|
"text": word,
|
||||||
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
|
"timestamp": (token_timestamps[indices[0]][0], token_timestamps[indices[-1]][1]),
|
||||||
|
**optional_language_field,
|
||||||
}
|
}
|
||||||
for word, indices in zip(words, token_indices)
|
for word, indices in zip(words, token_indices)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -322,7 +322,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
|
||||||
def test_return_timestamps_in_preprocess(self):
|
def test_return_timestamps_in_preprocess(self):
|
||||||
pipe = pipeline(
|
pipe = pipeline(
|
||||||
task="automatic-speech-recognition",
|
task="automatic-speech-recognition",
|
||||||
@@ -332,10 +331,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
|
data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
|
||||||
sample = next(iter(data))
|
sample = next(iter(data))
|
||||||
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe")
|
|
||||||
|
|
||||||
res = pipe(sample["audio"]["array"])
|
res = pipe(sample["audio"]["array"])
|
||||||
self.assertEqual(res, {"text": " Conquered returned to its place amidst the tents."})
|
self.assertEqual(res, {"text": " Conquered returned to its place amidst the tents."})
|
||||||
|
|
||||||
res = pipe(sample["audio"]["array"], return_timestamps=True)
|
res = pipe(sample["audio"]["array"], return_timestamps=True)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
res,
|
||||||
@@ -344,9 +343,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
"chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
|
"chunks": [{"timestamp": (0.0, 3.36), "text": " Conquered returned to its place amidst the tents."}],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
pipe.model.generation_config.alignment_heads = [[2, 2], [3, 0], [3, 2], [3, 3], [3, 4], [3, 5]]
|
|
||||||
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
|
||||||
|
|
||||||
|
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
||||||
# fmt: off
|
# fmt: off
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
res,
|
||||||
@@ -366,6 +364,63 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
def test_return_timestamps_and_language_in_preprocess(self):
|
||||||
|
pipe = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="openai/whisper-tiny",
|
||||||
|
chunk_length_s=8,
|
||||||
|
stride_length_s=1,
|
||||||
|
return_language=True,
|
||||||
|
)
|
||||||
|
data = load_dataset("openslr/librispeech_asr", "clean", split="test", streaming=True, trust_remote_code=True)
|
||||||
|
sample = next(iter(data))
|
||||||
|
|
||||||
|
res = pipe(sample["audio"]["array"])
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
"text": " Conquered returned to its place amidst the tents.",
|
||||||
|
"chunks": [{"language": "english", "text": " Conquered returned to its place amidst the tents."}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
res = pipe(sample["audio"]["array"], return_timestamps=True)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
"text": " Conquered returned to its place amidst the tents.",
|
||||||
|
"chunks": [
|
||||||
|
{
|
||||||
|
"timestamp": (0.0, 3.36),
|
||||||
|
"language": "english",
|
||||||
|
"text": " Conquered returned to its place amidst the tents.",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
res = pipe(sample["audio"]["array"], return_timestamps="word")
|
||||||
|
# fmt: off
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
'text': ' Conquered returned to its place amidst the tents.',
|
||||||
|
'chunks': [
|
||||||
|
{"language": "english",'text': ' Conquered', 'timestamp': (0.5, 1.2)},
|
||||||
|
{"language": "english", 'text': ' returned', 'timestamp': (1.2, 1.64)},
|
||||||
|
{"language": "english",'text': ' to', 'timestamp': (1.64, 1.84)},
|
||||||
|
{"language": "english",'text': ' its', 'timestamp': (1.84, 2.02)},
|
||||||
|
{"language": "english",'text': ' place', 'timestamp': (2.02, 2.28)},
|
||||||
|
{"language": "english",'text': ' amidst', 'timestamp': (2.28, 2.8)},
|
||||||
|
{"language": "english",'text': ' the', 'timestamp': (2.8, 2.98)},
|
||||||
|
{"language": "english",'text': ' tents.', 'timestamp': (2.98, 3.48)},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_return_timestamps_in_preprocess_longform(self):
|
def test_return_timestamps_in_preprocess_longform(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user