[AutoProcessor] Add Wav2Vec2WithLM & small fix (#14675)
* [AutoProcessor] Add Wav2Vec2WithLM & small fix * revert line removal * Update src/transformers/__init__.py * add test * up * up * small fix
This commit is contained in:
committed by
GitHub
parent
2294071a0c
commit
ee4fa2e465
@@ -1,3 +1,4 @@
|
||||
{
|
||||
"feature_extractor_type": "Wav2Vec2FeatureExtractor"
|
||||
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
|
||||
"processor_class": "Wav2Vec2Processor"
|
||||
}
|
||||
|
||||
@@ -16,15 +16,16 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from shutil import copyfile
|
||||
|
||||
from transformers import AutoProcessor, Wav2Vec2Config, Wav2Vec2Processor
|
||||
from transformers.file_utils import FEATURE_EXTRACTOR_NAME
|
||||
|
||||
|
||||
SAMPLE_PROCESSOR_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures")
|
||||
SAMPLE_PROCESSOR_CONFIG = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy_feature_extractor_config.json"
|
||||
)
|
||||
SAMPLE_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")
|
||||
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/vocab.json")
|
||||
|
||||
|
||||
class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
@@ -32,7 +33,7 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
def test_processor_from_local_directory_from_config(self):
|
||||
def test_processor_from_local_directory_from_repo(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model_config = Wav2Vec2Config()
|
||||
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
@@ -44,3 +45,13 @@ class AutoFeatureExtractorTest(unittest.TestCase):
|
||||
processor = AutoProcessor.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
def test_processor_from_local_directory_from_extractor_config(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# copy relevant files
|
||||
copyfile(SAMPLE_PROCESSOR_CONFIG, os.path.join(tmpdirname, FEATURE_EXTRACTOR_NAME))
|
||||
copyfile(SAMPLE_VOCAB, os.path.join(tmpdirname, "vocab.json"))
|
||||
|
||||
processor = AutoProcessor.from_pretrained(tmpdirname)
|
||||
|
||||
self.assertIsInstance(processor, Wav2Vec2Processor)
|
||||
|
||||
@@ -31,7 +31,7 @@ from .test_feature_extraction_wav2vec2 import floats_list
|
||||
|
||||
if is_pyctcdecode_available():
|
||||
from pyctcdecode import BeamSearchDecoderCTC
|
||||
from transformers.models.wav2vec2 import Wav2Vec2ProcessorWithLM
|
||||
from transformers.models.wav2vec2_with_lm import Wav2Vec2ProcessorWithLM
|
||||
|
||||
|
||||
@require_pyctcdecode
|
||||
|
||||
Reference in New Issue
Block a user