Fix ASR pipelines from local directories with wav2vec models that have language models attached (#15590)
* Fix loading pipelines with wav2vec models with lm when in local paths * Adding tests * Fix test * Adding tests * Flake8 fixes * Removing conflict files :( * Adding task type to test * Remove unnecessary test and imports
This commit is contained in:
committed by
GitHub
parent
e1cbc073bf
commit
9eb7e9ba1d
@@ -18,6 +18,7 @@ import numpy as np
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
MODEL_FOR_CTC_MAPPING,
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
@@ -368,6 +369,27 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@require_pyctcdecode
|
||||
def test_with_local_lm_fast(self):
|
||||
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model=local_dir,
|
||||
)
|
||||
self.assertEqual(speech_recognizer.type, "ctc_with_lm")
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
audio = ds[40]["audio"]["array"]
|
||||
|
||||
n_repeats = 2
|
||||
audio_tiled = np.tile(audio, n_repeats)
|
||||
|
||||
output = speech_recognizer([audio_tiled], batch_size=2)
|
||||
|
||||
self.assertEqual(output, [{"text": ANY(str)}])
|
||||
self.assertEqual(output[0]["text"][:6], "<s> <s")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_chunking(self):
|
||||
|
||||
@@ -31,6 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from huggingface_hub import snapshot_download
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
|
||||
@@ -303,3 +304,20 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
# https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
|
||||
# are downloaded and none of the rest (e.g. README.md, ...)
|
||||
self.assertListEqual(downloaded_decoder_files, expected_decoder_files)
|
||||
|
||||
def test_decoder_local_files(self):
|
||||
local_dir = snapshot_download("hf-internal-testing/processor_with_lm")
|
||||
|
||||
processor = Wav2Vec2ProcessorWithLM.from_pretrained(local_dir)
|
||||
|
||||
language_model = processor.decoder.model_container[processor.decoder._model_key]
|
||||
path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()
|
||||
|
||||
local_decoder_files = os.listdir(local_dir)
|
||||
expected_decoder_files = os.listdir(path_to_cached_dir)
|
||||
|
||||
local_decoder_files.sort()
|
||||
expected_decoder_files.sort()
|
||||
|
||||
# test that both decoder form hub and local files in cache are the same
|
||||
self.assertListEqual(local_decoder_files, expected_decoder_files)
|
||||
|
||||
Reference in New Issue
Block a user