[PretrainedFeatureExtractor] + Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2Tokenizer (#10324)
* push to show
* small improvement
* small improvement
* Update src/transformers/feature_extraction_utils.py
* Update src/transformers/feature_extraction_utils.py
* implement base
* add common tests
* make all tests pass for wav2vec2
* make padding work & add more tests
* finalize feature extractor utils
* add call method to feature extraction
* finalize feature processor
* finish tokenizer
* finish general processor design
* finish tests
* typo
* remove bogus file
* finish docstring
* add docs
* finish docs
* small fix
* correct docs
* save intermediate
* load changes
* apply changes
* apply changes to doc
* change tests
* apply surajs recommend
* final changes
* Apply suggestions from code review
* fix typo
* fix import
* correct docstring
This commit is contained in:
committed by
GitHub
parent
9dc7825744
commit
cb38ffcc5e
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@@ -324,17 +324,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_ctc_normal(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
input_values = tokenizer(input_speech, return_tensors="pt").input_values.to(torch_device)
|
||||
input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
@@ -342,11 +341,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_ctc_normal_batched(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(2)
|
||||
|
||||
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
|
||||
@@ -354,7 +353,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
logits = model(input_values).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
@@ -364,11 +363,11 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_inference_ctc_robust_batched(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
@@ -377,7 +376,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
logits = model(input_values, attention_mask=attention_mask).logits
|
||||
|
||||
predicted_ids = torch.argmax(logits, dim=-1)
|
||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||
predicted_trans = processor.batch_decode(predicted_ids)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"a man said to the universe sir i exist",
|
||||
|
||||
Reference in New Issue
Block a user