[Whispe] Fix pipeline after timestamp merges (#21198)
* pass return_timestamps to pre-process * add a test to test it * test does not need device 0 * remove failing bit * update test
This commit is contained in:
@@ -400,6 +400,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
" only 1 version"
|
" only 1 version"
|
||||||
)
|
)
|
||||||
forward_params["generate_kwargs"].update(generate_kwargs)
|
forward_params["generate_kwargs"].update(generate_kwargs)
|
||||||
|
if return_timestamps is not None:
|
||||||
|
forward_params["generate_kwargs"]["return_timestamps"] = return_timestamps
|
||||||
|
|
||||||
postprocess_params = {}
|
postprocess_params = {}
|
||||||
if decoder_kwargs is not None:
|
if decoder_kwargs is not None:
|
||||||
@@ -523,6 +525,8 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
generate_kwargs = {}
|
generate_kwargs = {}
|
||||||
|
|
||||||
is_last = model_inputs.pop("is_last")
|
is_last = model_inputs.pop("is_last")
|
||||||
|
return_timestamps = generate_kwargs.pop("return_timestamps", False)
|
||||||
|
|
||||||
if self.type == "seq2seq":
|
if self.type == "seq2seq":
|
||||||
encoder = self.model.get_encoder()
|
encoder = self.model.get_encoder()
|
||||||
# Consume values so we can let extra information flow freely through
|
# Consume values so we can let extra information flow freely through
|
||||||
@@ -552,7 +556,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
|
|||||||
stride = model_inputs.pop("stride", None)
|
stride = model_inputs.pop("stride", None)
|
||||||
tokens = self.model.generate(
|
tokens = self.model.generate(
|
||||||
input_features=model_inputs.pop("input_features"),
|
input_features=model_inputs.pop("input_features"),
|
||||||
logits_processor=[WhisperTimeStampLogitsProcessor()],
|
logits_processor=[WhisperTimeStampLogitsProcessor()] if return_timestamps else None,
|
||||||
**generate_kwargs,
|
**generate_kwargs,
|
||||||
)
|
)
|
||||||
out = {"tokens": tokens}
|
out = {"tokens": tokens}
|
||||||
|
|||||||
@@ -291,6 +291,29 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = speech_recognizer(filename)
|
output = speech_recognizer(filename)
|
||||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
def test_return_timestamps_in_preprocess(self):
|
||||||
|
pipe = pipeline(
|
||||||
|
task="automatic-speech-recognition",
|
||||||
|
model="openai/whisper-tiny",
|
||||||
|
chunk_length_s=8,
|
||||||
|
stride_length_s=1,
|
||||||
|
)
|
||||||
|
data = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
|
||||||
|
sample = next(iter(data))
|
||||||
|
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||||
|
|
||||||
|
res = pipe(sample["audio"]["array"])
|
||||||
|
self.assertEqual(res, {"text": " Conquered returned to its place amidst the tents."})
|
||||||
|
res = pipe(sample["audio"]["array"], return_timestamps=True)
|
||||||
|
self.assertEqual(
|
||||||
|
res,
|
||||||
|
{
|
||||||
|
"text": " Conquered returned to its place amidst the tents.",
|
||||||
|
"chunks": [{"text": " Conquered returned to its place amidst the tents.", "timestamp": (0.0, 3.36)}],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_whisper(self):
|
def test_torch_whisper(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user