From 68cc4ccde25b1ed8a9b77fdf6f78c833bdff0e9c Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 12 Jan 2022 09:28:19 +0100 Subject: [PATCH] Pipeline ASR with LM. (#15071) * Pipeline ASR with LM. * Revamped into `self.decoder`. * Fixing. * 2nd fix. * Update src/transformers/pipelines/__init__.py Co-authored-by: Patrick von Platen * Fixing. Co-authored-by: Patrick von Platen --- src/transformers/pipelines/__init__.py | 21 +++++++++ .../pipelines/automatic_speech_recognition.py | 43 ++++++++++++++----- ..._pipelines_automatic_speech_recognition.py | 32 ++++++++++++++ 3 files changed, 85 insertions(+), 11 deletions(-) diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index e65068e534..77259462f9 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -611,6 +611,27 @@ def pipeline( feature_extractor, revision=revision, _from_pipeline=task, **model_kwargs ) + if ( + feature_extractor._processor_class + and feature_extractor._processor_class.endswith("WithLM") + and isinstance(model_name, str) + ): + try: + from pyctcdecode import BeamSearchDecoderCTC + + language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*") + alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME + allow_regex = [language_model_glob, alphabet_filename] + + decoder = BeamSearchDecoderCTC.load_from_hf_hub( + pretrained_model_name_or_path, allow_regex=allow_regex + ) + kwargs["decoder"] = decoder + except Exception as e: + logger.warning( + "Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}" + ) + if task == "translation" and model.config.task_specific_params: for key in model.config.task_specific_params: if key.startswith("translation"): diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index e4745bcd81..d059e5f407 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -144,7 +144,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): raise ValueError("The AutomaticSpeechRecognitionPipeline is only available in PyTorch.") self.check_model_type(dict(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items() + MODEL_FOR_CTC_MAPPING.items())) - self.is_ctc = self.model.__class__ in MODEL_FOR_CTC_MAPPING.values() + + if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values(): + self.type = "seq2seq" + elif ( + self.feature_extractor._processor_class + and self.feature_extractor._processor_class.endswith("WithLM") + and kwargs.get("decoder", None) is not None + ): + self.decoder = kwargs["decoder"] + self.type = "ctc_with_lm" + else: + self.type = "ctc" def __call__( self, @@ -222,8 +233,7 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): def _forward(self, model_inputs): is_last = model_inputs.pop("is_last") - model_class = self.model.__class__ - if model_class in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values(): + if self.type == "seq2seq": encoder = self.model.get_encoder() # we need to pass `processed.get("attention_mask")` here since audio encoder # attention mask length is different from expected text decoder `encoder_attention_mask` length @@ -232,7 +242,12 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): tokens = self.model.generate( encoder_outputs=encoder(**model_inputs), attention_mask=model_inputs.get("attention_mask") ) - elif model_class in MODEL_FOR_CTC_MAPPING.values(): + out = {"tokens": tokens} + elif self.type == "ctc_with_lm": + outputs = self.model(**model_inputs) + out = {"logits": outputs.logits} + + elif self.type == "ctc": stride = model_inputs.pop("stride", None) outputs = self.model(**model_inputs) tokens = outputs.logits.argmax(dim=-1) @@ -241,16 +256,22 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline): stride = [stride] apply_stride(tokens, stride) + out = {"tokens": tokens} else: logger.warning("This is an unknown class, treating it as CTC.") outputs = self.model(**model_inputs) tokens = outputs.logits.argmax(dim=-1) - return {"tokens": tokens, "is_last": is_last} + out = {"tokens": tokens} + return {"is_last": is_last, **out} def postprocess(self, model_outputs): - skip_special_tokens = False if "CTC" in self.tokenizer.__class__.__name__ else True - tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1) - tokens = tokens.squeeze(0) - - recognized_string = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) - return {"text": recognized_string} + if self.type == "ctc_with_lm": + logits = np.concatenate([outputs["logits"].numpy() for outputs in model_outputs], axis=1) + logits = logits.squeeze(0) + text = self.decoder.decode_beams(logits)[0][0] + else: + skip_special_tokens = self.type != "ctc" + tokens = np.concatenate([outputs["tokens"].numpy() for outputs in model_outputs], axis=-1) + tokens = tokens.squeeze(0) + text = self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) + return {"text": text} diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index f951b8a90f..c77a5c56fc 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -32,6 +32,7 @@ from transformers.testing_utils import ( is_pipeline_test, is_torch_available, nested_simplify, + require_pyctcdecode, require_tf, require_torch, require_torchaudio, @@ -97,6 +98,37 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel output = speech_recognizer(waveform) self.assertEqual(output, {"text": "(Applaudissements)"}) + @slow + @require_torch + @require_pyctcdecode + def test_large_model_pt_with_lm(self): + dataset = load_dataset("Narsil/asr_dummy") + filename = dataset["test"][3]["file"] + + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm", + framework="pt", + ) + self.assertEqual(speech_recognizer.type, "ctc_with_lm") + + output = speech_recognizer(filename) + self.assertEqual( + output, + {"text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumaje"}, + ) + + # Override back to pure CTC + speech_recognizer.type = "ctc" + output = speech_recognizer(filename) + # plumajre != plumaje + self.assertEqual( + output, + { + "text": "y en las ramas medio sumergidas revoloteaban algunos pájaros de quimérico y legendario plumajre" + }, + ) + @require_tf def test_small_model_tf(self): self.skipTest("Tensorflow not supported yet.")