From 463226e2ee372ae48f473cd9f93917839f0901ff Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 14 Oct 2022 17:12:21 +0200 Subject: [PATCH] Improve error messaging for ASR pipeline. (#19570) * Improve error messaging for ASR pipeline. - Raise error early (in `_sanitize`) so users don't waste time trying to run queries with invalid params. - Fix the error was after using `config.inputs_to_logits_ratio` so our check was masked by the failing property does not exist. - Added some manual check on s2t for the error message. No non ctc model seems to be used by the default runner (they are all skipped). * Removing pdb. * Stop the early error it doesn't really work :(. --- .../pipelines/automatic_speech_recognition.py | 8 ++++---- ..._pipelines_automatic_speech_recognition.py | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 1d5546edb5..c6bf4c0958 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -250,6 +250,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") if chunk_length_s: + if self.type not in {"ctc", "ctc_with_lm"}: + raise ValueError( + "`chunk_length_s` is only valid for CTC models, use other chunking options for other models" + ) if stride_length_s is None: stride_length_s = chunk_length_s / 6 @@ -264,10 +268,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to) stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to) - if self.type not in {"ctc", "ctc_with_lm"}: - raise ValueError( - "`chunk_length_s` is only valid for CTC models, use other chunking options for other models" - ) if chunk_len < stride_left + stride_right: raise ValueError("Chunk length must be superior to stride length") diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index f73fda39e9..2aacbccbd4 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel }, ) else: + # Non CTC models cannot use chunk_length + with self.assertRaises(ValueError) as v: + outputs = speech_recognizer(audio, chunk_length_s=10) + self.assertEqual(v.exception, "") + # Non CTC models cannot use return_timestamps - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as v: outputs = speech_recognizer(audio, return_timestamps="char") + self.assertEqual(v.exception, "") @require_torch @slow @@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel waveform = np.tile(np.arange(1000, dtype=np.float32), 34) output = speech_recognizer(waveform) self.assertEqual(output, {"text": "(Applaudissements)"}) + with self.assertRaises(ValueError) as v: + _ = speech_recognizer(waveform, chunk_length_s=10) + self.assertEqual( + str(v.exception), + "`chunk_length_s` is only valid for CTC models, use other chunking options for other models", + ) + + # Non CTC models cannot use return_timestamps + with self.assertRaises(ValueError) as v: + _ = speech_recognizer(waveform, return_timestamps="char") + self.assertEqual(str(v.exception), "We cannot return_timestamps yet on non-ctc models !") @require_torch def test_small_model_pt_seq2seq(self):