Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC (#10089)

* add wav2vec2CTC and deprecate for maskedlm

* remove from docs
This commit is contained in:
Patrick von Platen
2021-02-09 11:49:02 +03:00
committed by Lysandre
parent 800f385d78
commit 02451cda74
8 changed files with 100 additions and 10 deletions

View File

@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
if is_torch_available():
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
class Wav2Vec2ModelTester:
@@ -204,7 +204,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
@require_torch
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
test_pruning = False
test_headmasking = False
test_torchscript = False
@@ -289,7 +289,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
return ds["speech"][:num_samples]
def test_inference_masked_lm_normal(self):
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -307,7 +307,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_normal_batched(self):
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
model.to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -330,7 +330,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
def test_inference_masked_lm_robust_batched(self):
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
input_speech = self._load_datasamples(4)