From 9eb7e9ba1d132eec947e95988f90ddc41e3bb65d Mon Sep 17 00:00:00 2001 From: Javier de la Rosa Date: Tue, 15 Feb 2022 13:45:08 +0100 Subject: [PATCH] Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590) * Fix loading pipelines with wav2vec models with lm when in local paths * Adding tests * Fix test * Adding tests * Flake8 fixes * Removing conflict files :( * Adding task type to test * Remove unnecessary test and imports --- .../processing_wav2vec2_with_lm.py | 2 +- src/transformers/pipelines/__init__.py | 15 ++++++++----- ..._pipelines_automatic_speech_recognition.py | 22 +++++++++++++++++++ tests/test_processor_wav2vec2_with_lm.py | 18 +++++++++++++++ 4 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py index c31b209c18..4947ce39ae 100644 --- a/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +++ b/src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py @@ -127,7 +127,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin): feature_extractor, tokenizer = super()._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs) - if os.path.isdir(pretrained_model_name_or_path): + if os.path.isdir(pretrained_model_name_or_path) or os.path.isfile(pretrained_model_name_or_path): decoder = BeamSearchDecoderCTC.load_from_dir(pretrained_model_name_or_path) else: # BeamSearchDecoderCTC has no auto class diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index fab5ccb008..4b5f2c1663 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -621,15 +621,20 @@ def pipeline( import kenlm # to trigger `ImportError` if not installed from pyctcdecode import BeamSearchDecoderCTC - language_model_glob = os.path.join(BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*") - alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME - allow_regex = [language_model_glob, alphabet_filename] + if os.path.isdir(model_name) or os.path.isfile(model_name): + decoder = BeamSearchDecoderCTC.load_from_dir(model_name) + else: + language_model_glob = os.path.join( + BeamSearchDecoderCTC._LANGUAGE_MODEL_SERIALIZED_DIRECTORY, "*" + ) + alphabet_filename = BeamSearchDecoderCTC._ALPHABET_SERIALIZED_FILENAME + allow_regex = [language_model_glob, alphabet_filename] + decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex) - decoder = BeamSearchDecoderCTC.load_from_hf_hub(model_name, allow_regex=allow_regex) kwargs["decoder"] = decoder except ImportError as e: logger.warning( - "Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}" + f"Could not load the `decoder` for {model_name}. Defaulting to raw CTC. Try to install `pyctcdecode` and `kenlm`: (`pip install pyctcdecode`, `pip install https://github.com/kpu/kenlm/archive/master.zip`): Error: {e}" ) if task == "translation" and model.config.task_specific_params: diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index 15b5f72612..37ab808e77 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -18,6 +18,7 @@ import numpy as np import pytest from datasets import load_dataset +from huggingface_hub import snapshot_download from transformers import ( MODEL_FOR_CTC_MAPPING, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, @@ -368,6 +369,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel self.assertEqual(output, [{"text": ANY(str)}]) self.assertEqual(output[0]["text"][:6], "