[PretrainedFeatureExtractor] + Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2Tokenizer (#10324)

* push to show

* small improvement

* small improvement

* Update src/transformers/feature_extraction_utils.py

* Update src/transformers/feature_extraction_utils.py

* implement base

* add common tests

* make all tests pass for wav2vec2

* make padding work & add more tests

* finalize feature extractor utils

* add call method to feature extraction

* finalize feature processor

* finish tokenizer

* finish general processor design

* finish tests

* typo

* remove bogus file

* finish docstring

* add docs

* finish docs

* small fix

* correct docs

* save intermediate

* load changes

* apply changes

* apply changes to doc

* change tests

* apply surajs recommend

* final changes

* Apply suggestions from code review

* fix typo

* fix import

* correct docstring
This commit is contained in:
Patrick von Platen
2021-02-25 17:42:46 +03:00
committed by GitHub
parent 9dc7825744
commit cb38ffcc5e
33 changed files with 2252 additions and 176 deletions

View File

@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
class Wav2Vec2ModelTester:
@@ -324,17 +324,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def test_inference_ctc_normal(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(1)
input_values = tokenizer(input_speech, return_tensors="pt").input_values.to(torch_device)
input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
with torch.no_grad():
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
@@ -342,11 +341,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def test_inference_ctc_normal_batched(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
input_speech = self._load_datasamples(2)
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values.to(torch_device)
@@ -354,7 +353,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",
@@ -364,11 +363,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
def test_inference_ctc_robust_batched(self):
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)
@@ -377,7 +376,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
logits = model(input_values, attention_mask=attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
predicted_trans = tokenizer.batch_decode(predicted_ids)
predicted_trans = processor.batch_decode(predicted_ids)
EXPECTED_TRANSCRIPTIONS = [
"a man said to the universe sir i exist",