From 3fefee99108de855f5659679c9d034a3be5ad0f4 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 19 Jan 2022 21:04:26 +0100 Subject: [PATCH] Make chunking smartly (long files) work on asr ctc_with_lm. (#15219) * [WIP] Make chunking smartly (long files) work on asr ctc_with_lm. * Slow test with functionality. * Fixing regular test. * fix for batch size 1 * Handling batch outside `rescale_Stride`. - Renamed to `rescale_stride`. * Disable equality in the test. * Remove print. Co-authored-by: Patrick von Platen --- .../pipelines/automatic_speech_recognition.py | 65 ++++++++++++++++--- ..._pipelines_automatic_speech_recognition.py | 49 +++++++++++++- 2 files changed, 103 insertions(+), 11 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index f3bdb4277e..cd24ac33ee 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -66,14 +66,34 @@ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array: return audio -def apply_stride(tokens, stride): - max_token_n = tokens.shape[-1] +def rescale_stride(tokens_or_logits, stride): + """ + Rescales the stride values from audio space to tokens/logits space. + + (160_000, 16_000, 16_000) -> (2000, 200, 200) for instance. 
+ """ + # Shape is [B, SEQ] for tokens + # [B, SEQ, V] for logits + + max_token_n = tokens_or_logits.shape[1] max_input_n = max(input_n for input_n, _, _ in stride) ratio = max_token_n / max_input_n - for i, (input_n, left, right) in enumerate(stride): + new_strides = [] + for input_n, left, right in stride: token_n = int(round(input_n * ratio)) - left_token = int(round(left / input_n * token_n)) - right_token = int(round((input_n - right) / input_n * token_n)) + left = int(round(left / input_n * token_n)) + right = int(round(right / input_n * token_n)) + new_stride = (token_n, left, right) + new_strides.append(new_stride) + + return new_strides + + +def apply_stride(tokens, stride): + new_stride = rescale_stride(tokens, stride) + for i, (input_n, left, right) in enumerate(new_stride): + left_token = left + right_token = input_n - right # This is CTC to preseve decoding, we need to duplicate # next letter, and last letter @@ -215,7 +235,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate)) stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate)) - if self.type != "ctc": + if self.type not in {"ctc", "ctc_with_lm"}: raise ValueError( "`chunk_length_s` is only valid for CTC models, use other chunking options for other models" ) @@ -244,9 +264,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): ) out = {"tokens": tokens} elif self.type == "ctc_with_lm": + stride = model_inputs.pop("stride", None) outputs = self.model(**model_inputs) - out = {"logits": outputs.logits} - + logits = outputs.logits + out = {"logits": logits} + if stride is not None: + # Send stride to `postprocess`. + # it needs to be handled there where + # the pieces are to be concatenated. 
+ if isinstance(stride, tuple): + out["stride"] = rescale_stride(logits, [stride])[0] + else: + out["stride"] = rescale_stride(logits, stride) elif self.type == "ctc": stride = model_inputs.pop("stride", None) outputs = self.model(**model_inputs) @@ -266,7 +295,25 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def postprocess(self, model_outputs): if self.type == "ctc_with_lm": - logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1) + final_logits = [] + for outputs in model_outputs: + logits = outputs["logits"].numpy() + stride = outputs.get("stride", None) + if stride is not None: + try: + total_n, left, right = stride + except Exception: + import ipdb + + ipdb.set_trace() + # Total_n might be < logits.shape[1] + # because of padding, that's why + # we need to reconstruct this information + # This won't work with left padding (which doesn't exist right now) + right_n = total_n - right + logits = logits[:, left:right_n] + final_logits.append(logits) + logits = np.concatenate(final_logits, axis=1) logits = logits.squeeze(0) text = self.decoder.decode_beams(logits)[0][0] else: diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index c64f6b69dc..fd3d54c4c4 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -295,13 +295,39 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output[0]["text"][:6], "ZBT ZC") + @require_torch + @require_pyctcdecode + def test_chunking_fast_with_lm(self): + speech_recognizer = pipeline( + model="hf-internal-testing/processor_with_lm", + chunk_length_s=10.0, + ) + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + audio = ds[40]["audio"]["array"] + + n_repeats = 2 + audio_tiled = np.tile(audio, 
n_repeats) + # Batch_size = 1 + output1 = speech_recognizer([audio_tiled], batch_size=1) + self.assertEqual(output1, [{"text": ANY(str)}]) + self.assertEqual(output1[0]["text"][:6], "