Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM (#28706)
* Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM * update with a type filter * add raises error test * fix added test
This commit is contained in:
@@ -25,7 +25,7 @@ import numpy as np
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import AutoProcessor
|
||||
from transformers import AutoFeatureExtractor, AutoProcessor
|
||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
|
||||
@@ -157,6 +157,35 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
def test_another_feature_extractor(self):
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
|
||||
tokenizer = self.get_tokenizer()
|
||||
decoder = self.get_decoder()
|
||||
|
||||
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
|
||||
|
||||
raw_speech = floats_list((3, 1000))
|
||||
|
||||
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
|
||||
input_processor = processor(raw_speech, return_tensors="np")
|
||||
|
||||
for key in input_feat_extract.keys():
|
||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||
|
||||
self.assertListEqual(
|
||||
processor.model_input_names,
|
||||
feature_extractor.model_input_names,
|
||||
msg="`processor` and `feature_extractor` model input names do not match",
|
||||
)
|
||||
|
||||
def test_wrong_feature_extractor_raises_error(self):
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
|
||||
tokenizer = self.get_tokenizer()
|
||||
decoder = self.get_decoder()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
|
||||
|
||||
def test_tokenizer(self):
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
Reference in New Issue
Block a user