Improve error messaging for ASR pipeline. (#19570)

* Improve error messaging for ASR pipeline.

- Raise error early (in `_sanitize`) so users don't waste time trying to
  run queries with invalid params.

- Fix the error was after using `config.inputs_to_logits_ratio` so our
  check was masked by the failing property does not exist.

- Added some manual check on s2t for the error message.
  No non ctc model seems to be used by the default runner (they are all
  skipped).

* Removing pdb.

* Stop the early error it doesn't really work :(.
This commit is contained in:
Nicolas Patry
2022-10-14 17:12:21 +02:00
committed by GitHub
parent 5ef2186692
commit 463226e2ee
2 changed files with 22 additions and 5 deletions

View File

@@ -250,6 +250,10 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline") raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
if chunk_length_s: if chunk_length_s:
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
if stride_length_s is None: if stride_length_s is None:
stride_length_s = chunk_length_s / 6 stride_length_s = chunk_length_s / 6
@@ -264,10 +268,6 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to) stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to) stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models"
)
if chunk_len < stride_left + stride_right: if chunk_len < stride_left + stride_right:
raise ValueError("Chunk length must be superior to stride length") raise ValueError("Chunk length must be superior to stride length")

View File

@@ -118,9 +118,15 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
}, },
) )
else: else:
# Non CTC models cannot use chunk_length
with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, chunk_length_s=10)
self.assertEqual(v.exception, "")
# Non CTC models cannot use return_timestamps # Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError): with self.assertRaises(ValueError) as v:
outputs = speech_recognizer(audio, return_timestamps="char") outputs = speech_recognizer(audio, return_timestamps="char")
self.assertEqual(v.exception, "")
@require_torch @require_torch
@slow @slow
@@ -138,6 +144,17 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
waveform = np.tile(np.arange(1000, dtype=np.float32), 34) waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform) output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"}) self.assertEqual(output, {"text": "(Applaudissements)"})
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(
str(v.exception),
"`chunk_length_s` is only valid for CTC models, use other chunking options for other models",
)
# Non CTC models cannot use return_timestamps
with self.assertRaises(ValueError) as v:
_ = speech_recognizer(waveform, return_timestamps="char")
self.assertEqual(str(v.exception), "We cannot return_timestamps yet on non-ctc models !")
@require_torch @require_torch
def test_small_model_pt_seq2seq(self): def test_small_model_pt_seq2seq(self):