Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590)

* Fix loading pipelines with wav2vec models with lm when in local paths

* Adding tests

* Fix test

* Adding tests

* Flake8 fixes

* Removing conflict files :(

* Adding task type to test

* Remove unnecessary test and imports
This commit is contained in:
Javier de la Rosa
2022-02-15 13:45:08 +01:00
committed by GitHub
parent e1cbc073bf
commit 9eb7e9ba1d
4 changed files with 51 additions and 6 deletions

View File

@@ -18,6 +18,7 @@ import numpy as np
import pytest
from datasets import load_dataset
from huggingface_hub import snapshot_download
from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@@ -368,6 +369,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@require_pyctcdecode
def test_with_local_lm_fast(self):
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model=local_dir,
)
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
audio = ds[40]["audio"]["array"]
n_repeats = 2
audio_tiled = np.tile(audio, n_repeats)
output = speech_recognizer([audio_tiled], batch_size=2)
self.assertEqual(output, [{"text": ANY(str)}])
self.assertEqual(output[0]["text"][:6], "<s> <s")
@require_torch
@slow
def test_chunking(self):

View File

@@ -31,6 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available():
from huggingface_hub import snapshot_download
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
@@ -303,3 +304,20 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
# https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
# are downloaded and none of the rest (e.g. README.md, ...)
self.assertListEqual(downloaded_decoder_files, expected_decoder_files)
def test_decoder_local_files(self):
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
processor = Wav2Vec2ProcessorWithLM.from_pretrained(local_dir)
language_model = processor.decoder.model_container[processor.decoder._model_key]
path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
local_decoder_files = os.listdir(local_dir)
expected_decoder_files = os.listdir(path_to_cached_dir)
local_decoder_files.sort()
expected_decoder_files.sort()
# test that both decoder form hub and local files in cache are the same
self.assertListEqual(local_decoder_files, expected_decoder_files)