Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC (#10089)
* add wav2vec2CTC and deprecate for maskedlm * remove from docs
This commit is contained in:
committed by
Lysandre
parent
800f385d78
commit
02451cda74
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@@ -204,7 +204,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
|
||||
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -289,7 +289,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
return ds["speech"][:num_samples]
|
||||
|
||||
def test_inference_masked_lm_normal(self):
|
||||
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
@@ -307,7 +307,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_masked_lm_normal_batched(self):
|
||||
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||
|
||||
@@ -330,7 +330,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_masked_lm_robust_batched(self):
|
||||
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
Reference in New Issue
Block a user