Improve error messaging for ASR pipeline. (#19570)

* Improve error messaging for ASR pipeline.

- Raise error early (in `_sanitize`) so users don't waste time trying to
  run queries with invalid params.

- Fix the error was after using `config.inputs_to_logits_ratio` so our
  check was masked by the failing property does not exist.

- Added some manual check on s2t for the error message.
  No non ctc model seems to be used by the default runner (they are all
  skipped).

* Removing pdb.

* Stop the early error it doesn't really work :(.
This commit is contained in:
Nicolas Patry
2022-10-14 17:12:21 +02:00
committed by GitHub
parent 5ef2186692
commit 463226e2ee
2 changed files with 22 additions and 5 deletions

View File

@@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
},
)
else:
# Non CTC models cannot use chunk_length
with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, chunk_length_s=10)
self.assertEqual(v.exception, "")
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError):
with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(v.exception, "")
@require_torch
@slow
@@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(
str(v.exception),
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
)
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, return_timestamps="char")
self.assertEqual(str(v.exception), "We cannot return_timestamps yet on non-ctc models !")
@require_torch
def test_small_model_pt_seq2seq(self):