Deprecate Wav2Vec2ForMaskedLM and add Wav2Vec2ForCTC (#10089)
* add wav2vec2CTC and deprecate for maskedlm * remove from docs
This commit is contained in:
committed by
GitHub
parent
ba542ffb49
commit
b972125ced
@@ -58,8 +58,8 @@ Wav2Vec2Model
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
Wav2Vec2ForMaskedLM
|
Wav2Vec2ForCTC
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.Wav2Vec2ForMaskedLM
|
.. autoclass:: transformers.Wav2Vec2ForCTC
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|||||||
@@ -367,6 +367,7 @@ if is_torch_available():
|
|||||||
_import_structure["models.wav2vec2"].extend(
|
_import_structure["models.wav2vec2"].extend(
|
||||||
[
|
[
|
||||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"Wav2Vec2ForCTC",
|
||||||
"Wav2Vec2ForMaskedLM",
|
"Wav2Vec2ForMaskedLM",
|
||||||
"Wav2Vec2Model",
|
"Wav2Vec2Model",
|
||||||
"Wav2Vec2PreTrainedModel",
|
"Wav2Vec2PreTrainedModel",
|
||||||
@@ -1813,6 +1814,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.wav2vec2 import (
|
from .models.wav2vec2 import (
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
Wav2Vec2ForCTC,
|
||||||
Wav2Vec2ForMaskedLM,
|
Wav2Vec2ForMaskedLM,
|
||||||
Wav2Vec2Model,
|
Wav2Vec2Model,
|
||||||
Wav2Vec2PreTrainedModel,
|
Wav2Vec2PreTrainedModel,
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ if is_torch_available():
|
|||||||
_import_structure["modeling_wav2vec2"] = [
|
_import_structure["modeling_wav2vec2"] = [
|
||||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"Wav2Vec2ForMaskedLM",
|
"Wav2Vec2ForMaskedLM",
|
||||||
|
"Wav2Vec2ForCTC",
|
||||||
"Wav2Vec2Model",
|
"Wav2Vec2Model",
|
||||||
"Wav2Vec2PreTrainedModel",
|
"Wav2Vec2PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_wav2vec2 import (
|
from .modeling_wav2vec2 import (
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
Wav2Vec2ForCTC,
|
||||||
Wav2Vec2ForMaskedLM,
|
Wav2Vec2ForMaskedLM,
|
||||||
Wav2Vec2Model,
|
Wav2Vec2Model,
|
||||||
Wav2Vec2PreTrainedModel,
|
Wav2Vec2PreTrainedModel,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import argparse
|
|||||||
import fairseq
|
import fairseq
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging
|
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
@@ -141,7 +141,7 @@ def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_
|
|||||||
"""
|
"""
|
||||||
Copy/paste/tweak model's weights to transformers design.
|
Copy/paste/tweak model's weights to transformers design.
|
||||||
"""
|
"""
|
||||||
hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())
|
hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())
|
||||||
|
|
||||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||||
[checkpoint_path], arg_overrides={"data": dict_path}
|
[checkpoint_path], arg_overrides={"data": dict_path}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
""" PyTorch Wav2Vec2 model. """
|
""" PyTorch Wav2Vec2 model. """
|
||||||
|
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -24,7 +25,7 @@ from torch import nn
|
|||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||||
from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
|
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||||
@@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
|
||||||
|
)
|
||||||
|
|
||||||
self.wav2vec2 = Wav2Vec2Model(config)
|
self.wav2vec2 = Wav2Vec2Model(config)
|
||||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
@@ -729,3 +734,77 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||||
|
WAV_2_VEC_2_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.wav2vec2 = Wav2Vec2Model(config)
|
||||||
|
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||||
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_values,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||||
|
TODO(PVP): Fill out when adding training
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2Model
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import soundfile as sf
|
||||||
|
|
||||||
|
>>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
>>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
|
|
||||||
|
>>> def map_to_array(batch):
|
||||||
|
>>> speech, _ = sf.read(batch["file"])
|
||||||
|
>>> batch["speech"] = speech
|
||||||
|
>>> return batch
|
||||||
|
|
||||||
|
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> ds = ds.map(map_to_array)
|
||||||
|
|
||||||
|
>>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||||
|
>>> logits = model(input_values).logits
|
||||||
|
|
||||||
|
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
|
>>> transcription = tokenizer.decode(predicted_ids[0])
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
outputs = self.wav2vec2(
|
||||||
|
input_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return output
|
||||||
|
|
||||||
|
return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
||||||
|
|||||||
@@ -2229,6 +2229,11 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
|
|||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2ForCTC:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_pytorch(self)
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ForMaskedLM:
|
class Wav2Vec2ForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_pytorch(self)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ModelTester:
|
class Wav2Vec2ModelTester:
|
||||||
@@ -204,7 +204,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
|
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
@@ -289,7 +289,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
return ds["speech"][:num_samples]
|
return ds["speech"][:num_samples]
|
||||||
|
|
||||||
def test_inference_masked_lm_normal(self):
|
def test_inference_masked_lm_normal(self):
|
||||||
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||||
|
|
||||||
@@ -307,7 +307,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
def test_inference_masked_lm_normal_batched(self):
|
def test_inference_masked_lm_normal_batched(self):
|
||||||
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||||
|
|
||||||
@@ -330,7 +330,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
def test_inference_masked_lm_robust_batched(self):
|
def test_inference_masked_lm_robust_batched(self):
|
||||||
model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
|
|||||||
@@ -118,6 +118,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"TFMT5EncoderModel",
|
"TFMT5EncoderModel",
|
||||||
"TFOpenAIGPTDoubleHeadsModel",
|
"TFOpenAIGPTDoubleHeadsModel",
|
||||||
"TFT5EncoderModel",
|
"TFT5EncoderModel",
|
||||||
|
"Wav2Vec2ForCTC",
|
||||||
"XLMForQuestionAnswering",
|
"XLMForQuestionAnswering",
|
||||||
"XLMProphetNetDecoder",
|
"XLMProphetNetDecoder",
|
||||||
"XLMProphetNetEncoder",
|
"XLMProphetNetEncoder",
|
||||||
@@ -370,6 +371,7 @@ DEPRECATED_OBJECTS = [
|
|||||||
"TFBartPretrainedModel",
|
"TFBartPretrainedModel",
|
||||||
"TextDataset",
|
"TextDataset",
|
||||||
"TextDatasetForNextSentencePrediction",
|
"TextDatasetForNextSentencePrediction",
|
||||||
|
"Wav2Vec2ForMaskedLM",
|
||||||
"glue_compute_metrics",
|
"glue_compute_metrics",
|
||||||
"glue_convert_examples_to_features",
|
"glue_convert_examples_to_features",
|
||||||
"glue_output_modes",
|
"glue_output_modes",
|
||||||
|
|||||||
Reference in New Issue
Block a user