Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM (#28706)
* Add AutoFeatureExtractor support to Wav2Vec2ProcessorWithLM * update with a type filter * add raises error test * fix added test
This commit is contained in:
@@ -70,15 +70,15 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
with language model support into a single processor for language model boosted speech recognition decoding.
|
with language model support into a single processor for language model boosted speech recognition decoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
feature_extractor ([`Wav2Vec2FeatureExtractor`]):
|
feature_extractor ([`Wav2Vec2FeatureExtractor`] or [`SeamlessM4TFeatureExtractor`]):
|
||||||
An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is a required input.
|
An instance of [`Wav2Vec2FeatureExtractor`] or [`SeamlessM4TFeatureExtractor`]. The feature extractor is a required input.
|
||||||
tokenizer ([`Wav2Vec2CTCTokenizer`]):
|
tokenizer ([`Wav2Vec2CTCTokenizer`]):
|
||||||
An instance of [`Wav2Vec2CTCTokenizer`]. The tokenizer is a required input.
|
An instance of [`Wav2Vec2CTCTokenizer`]. The tokenizer is a required input.
|
||||||
decoder (`pyctcdecode.BeamSearchDecoderCTC`):
|
decoder (`pyctcdecode.BeamSearchDecoderCTC`):
|
||||||
An instance of [`pyctcdecode.BeamSearchDecoderCTC`]. The decoder is a required input.
|
An instance of [`pyctcdecode.BeamSearchDecoderCTC`]. The decoder is a required input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
feature_extractor_class = "Wav2Vec2FeatureExtractor"
|
feature_extractor_class = "AutoFeatureExtractor"
|
||||||
tokenizer_class = "Wav2Vec2CTCTokenizer"
|
tokenizer_class = "Wav2Vec2CTCTokenizer"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -93,6 +93,11 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
if not isinstance(decoder, BeamSearchDecoderCTC):
|
if not isinstance(decoder, BeamSearchDecoderCTC):
|
||||||
raise ValueError(f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__}, but is {type(decoder)}")
|
raise ValueError(f"`decoder` has to be of type {BeamSearchDecoderCTC.__class__}, but is {type(decoder)}")
|
||||||
|
|
||||||
|
if feature_extractor.__class__.__name__ not in ["Wav2Vec2FeatureExtractor", "SeamlessM4TFeatureExtractor"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"`feature_extractor` has to be of type `Wav2Vec2FeatureExtractor` or `SeamlessM4TFeatureExtractor`, but is {type(feature_extractor)}"
|
||||||
|
)
|
||||||
|
|
||||||
# make sure that decoder's alphabet and tokenizer's vocab match in content
|
# make sure that decoder's alphabet and tokenizer's vocab match in content
|
||||||
missing_decoder_tokens = self.get_missing_alphabet_tokens(decoder, tokenizer)
|
missing_decoder_tokens = self.get_missing_alphabet_tokens(decoder, tokenizer)
|
||||||
if len(missing_decoder_tokens) > 0:
|
if len(missing_decoder_tokens) > 0:
|
||||||
@@ -117,7 +122,7 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
<Tip>
|
<Tip>
|
||||||
|
|
||||||
This class method is simply calling Wav2Vec2FeatureExtractor's
|
This class method is simply calling the feature extractor's
|
||||||
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
|
[`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], Wav2Vec2CTCTokenizer's
|
||||||
[`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
|
[`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], and
|
||||||
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
|
[`pyctcdecode.BeamSearchDecoderCTC.load_from_hf_hub`].
|
||||||
@@ -213,8 +218,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
|
When used in normal mode, this method forwards all its arguments to the feature extractor's
|
||||||
[`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context
|
[`~FeatureExtractionMixin.__call__`] and returns its output. If used in the context
|
||||||
[`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to
|
[`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to
|
||||||
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two
|
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.__call__`]. Please refer to the docstring of the above two
|
||||||
methods for more information.
|
methods for more information.
|
||||||
@@ -252,8 +257,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
|
|||||||
|
|
||||||
def pad(self, *args, **kwargs):
|
def pad(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
|
When used in normal mode, this method forwards all its arguments to the feature extractor's
|
||||||
[`~Wav2Vec2FeatureExtractor.pad`] and returns its output. If used in the context
|
[`~FeatureExtractionMixin.pad`] and returns its output. If used in the context
|
||||||
[`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to
|
[`~Wav2Vec2ProcessorWithLM.as_target_processor`] this method forwards all its arguments to
|
||||||
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods
|
Wav2Vec2CTCTokenizer's [`~Wav2Vec2CTCTokenizer.pad`]. Please refer to the docstring of the above two methods
|
||||||
for more information.
|
for more information.
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import numpy as np
|
|||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoFeatureExtractor, AutoProcessor
|
||||||
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
from transformers.models.wav2vec2 import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor
|
||||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||||
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
|
from transformers.testing_utils import require_pyctcdecode, require_torch, require_torchaudio, slow
|
||||||
@@ -157,6 +157,35 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
|||||||
for key in input_feat_extract.keys():
|
for key in input_feat_extract.keys():
|
||||||
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||||
|
|
||||||
|
def test_another_feature_extractor(self):
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
decoder = self.get_decoder()
|
||||||
|
|
||||||
|
processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
|
||||||
|
|
||||||
|
raw_speech = floats_list((3, 1000))
|
||||||
|
|
||||||
|
input_feat_extract = feature_extractor(raw_speech, return_tensors="np")
|
||||||
|
input_processor = processor(raw_speech, return_tensors="np")
|
||||||
|
|
||||||
|
for key in input_feat_extract.keys():
|
||||||
|
self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)
|
||||||
|
|
||||||
|
self.assertListEqual(
|
||||||
|
processor.model_input_names,
|
||||||
|
feature_extractor.model_input_names,
|
||||||
|
msg="`processor` and `feature_extractor` model input names do not match",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_wrong_feature_extractor_raises_error(self):
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3")
|
||||||
|
tokenizer = self.get_tokenizer()
|
||||||
|
decoder = self.get_decoder()
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)
|
||||||
|
|
||||||
def test_tokenizer(self):
|
def test_tokenizer(self):
|
||||||
feature_extractor = self.get_feature_extractor()
|
feature_extractor = self.get_feature_extractor()
|
||||||
tokenizer = self.get_tokenizer()
|
tokenizer = self.get_tokenizer()
|
||||||
|
|||||||
Reference in New Issue
Block a user