[AutoProcessor] Add Wav2Vec2WithLM & small fix (#14675)

* [AutoProcessor] Add Wav2Vec2WithLM & small fix

* revert line removal

* Update src/transformers/__init__.py

* add test

* up

* up

* small fix
This commit is contained in:
Patrick von Platen
2021-12-08 15:51:28 +01:00
committed by GitHub
parent 2294071a0c
commit ee4fa2e465
10 changed files with 72 additions and 16 deletions

View File

@@ -1,3 +1,4 @@
{
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"processor_class": "Wav2Vec2Processor"
}

View File

@@ -16,15 +16,16 @@
import os
import tempfile
import unittest
from shutil import copyfile
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
SAMPLE_PROCESSOR_CONFIG = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
)
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
class AutoFeatureExtractorTest(unittest.TestCase):
@@ -32,7 +33,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
self.assertIsInstance(processor, Wav2Vec2Processor)
def test_processor_from_local_directory_from_config(self):
def test_processor_from_local_directory_from_repo(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model_config = Wav2Vec2Config()
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
@@ -44,3 +45,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor)
def test_processor_from_local_directory_from_extractor_config(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# copy relevant files
copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))
processor = AutoProcessor.from_pretrained(tmpdirname)
self.assertIsInstance(processor, Wav2Vec2Processor)

View File

@@ -31,7 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
if is_pyctcdecode_available():
from pyctcdecode import BeamSearchDecoderCTC
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
@require_pyctcdecode