[Speech] Refactor Examples (#14040)
* adapt_examples * up * up * up * up * add auto models * finish
This commit is contained in:
committed by
GitHub
parent
2024faf171
commit
d5ff69fce9
@@ -59,3 +59,9 @@ SEWForCTC
|
|||||||
.. autoclass:: transformers.SEWForCTC
|
.. autoclass:: transformers.SEWForCTC
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
SEWForSequenceClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.SEWForSequenceClassification
|
||||||
|
:members: forward
|
||||||
|
|||||||
@@ -59,3 +59,8 @@ SEWDForCTC
|
|||||||
.. autoclass:: transformers.SEWDForCTC
|
.. autoclass:: transformers.SEWDForCTC
|
||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
SEWDForSequenceClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.SEWDForSequenceClassification
|
||||||
|
:members: forward
|
||||||
|
|||||||
@@ -1143,6 +1143,7 @@ if is_torch_available():
|
|||||||
[
|
[
|
||||||
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"SEWForCTC",
|
"SEWForCTC",
|
||||||
|
"SEWForSequenceClassification",
|
||||||
"SEWModel",
|
"SEWModel",
|
||||||
"SEWPreTrainedModel",
|
"SEWPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -1151,6 +1152,7 @@ if is_torch_available():
|
|||||||
[
|
[
|
||||||
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"SEWDForCTC",
|
"SEWDForCTC",
|
||||||
|
"SEWDForSequenceClassification",
|
||||||
"SEWDModel",
|
"SEWDModel",
|
||||||
"SEWDPreTrainedModel",
|
"SEWDPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2858,8 +2860,20 @@ if TYPE_CHECKING:
|
|||||||
RoFormerPreTrainedModel,
|
RoFormerPreTrainedModel,
|
||||||
load_tf_weights_in_roformer,
|
load_tf_weights_in_roformer,
|
||||||
)
|
)
|
||||||
from .models.sew import SEW_PRETRAINED_MODEL_ARCHIVE_LIST, SEWForCTC, SEWModel, SEWPreTrainedModel
|
from .models.sew import (
|
||||||
from .models.sew_d import SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST, SEWDForCTC, SEWDModel, SEWDPreTrainedModel
|
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
SEWForCTC,
|
||||||
|
SEWForSequenceClassification,
|
||||||
|
SEWModel,
|
||||||
|
SEWPreTrainedModel,
|
||||||
|
)
|
||||||
|
from .models.sew_d import (
|
||||||
|
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
SEWDForCTC,
|
||||||
|
SEWDForSequenceClassification,
|
||||||
|
SEWDModel,
|
||||||
|
SEWDPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
|
from .models.speech_encoder_decoder import SpeechEncoderDecoderModel
|
||||||
from .models.speech_to_text import (
|
from .models.speech_to_text import (
|
||||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
|||||||
@@ -476,6 +476,8 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
|||||||
# Model for Audio Classification mapping
|
# Model for Audio Classification mapping
|
||||||
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
|
("wav2vec2", "Wav2Vec2ForSequenceClassification"),
|
||||||
("hubert", "HubertForSequenceClassification"),
|
("hubert", "HubertForSequenceClassification"),
|
||||||
|
("sew", "SEWForSequenceClassification"),
|
||||||
|
("sew-d", "SEWDForSequenceClassification"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,12 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
from ...file_utils import (
|
||||||
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -36,6 +41,13 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
_CONFIG_FOR_DOC = "HubertConfig"
|
_CONFIG_FOR_DOC = "HubertConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "facebook/hubert-base-ls960"
|
_CHECKPOINT_FOR_DOC = "facebook/hubert-base-ls960"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
_SEQ_CLASS_CHECKPOINT = ("superb/hubert-base-superb-ks",)
|
||||||
|
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
|
|
||||||
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"facebook/hubert-base-ls960",
|
"facebook/hubert-base-ls960",
|
||||||
@@ -999,6 +1011,7 @@ class HubertModel(HubertPreTrainedModel):
|
|||||||
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
"""Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||||
HUBERT_START_DOCSTRING,
|
HUBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
||||||
class HubertForCTC(HubertPreTrainedModel):
|
class HubertForCTC(HubertPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1025,7 +1038,12 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
self.hubert.feature_extractor._freeze_parameters()
|
self.hubert.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=CausalLMOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
@@ -1041,41 +1059,6 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||||
config.vocab_size - 1]``.
|
config.vocab_size - 1]``.
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import Wav2Vec2Processor, HubertForCTC
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
|
|
||||||
>>> model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")
|
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
... speech, _ = sf.read(batch["file"])
|
|
||||||
... batch["speech"] = speech
|
|
||||||
... return batch
|
|
||||||
|
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
|
||||||
>>> logits = model(input_values).logits
|
|
||||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
|
||||||
|
|
||||||
>>> transcription = processor.decode(predicted_ids[0])
|
|
||||||
|
|
||||||
>>> # compute loss
|
|
||||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
|
||||||
|
|
||||||
>>> # wrap processor as target processor to encode labels
|
|
||||||
>>> with processor.as_target_processor():
|
|
||||||
... labels = processor(target_transcription, return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1126,7 +1109,7 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return CausalLMOutput(
|
return CausalLMOutput(
|
||||||
@@ -1141,8 +1124,8 @@ class HubertForCTC(HubertPreTrainedModel):
|
|||||||
""",
|
""",
|
||||||
HUBERT_START_DOCSTRING,
|
HUBERT_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
|
||||||
class HubertForSequenceClassification(HubertPreTrainedModel):
|
class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -1155,7 +1138,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
|
|
||||||
def freeze_feature_extractor(self):
|
def freeze_feature_extractor(self):
|
||||||
"""
|
"""
|
||||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||||
@@ -1163,7 +1145,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
self.hubert.feature_extractor._freeze_parameters()
|
self.hubert.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
|
|
||||||
def freeze_base_model(self):
|
def freeze_base_model(self):
|
||||||
"""
|
"""
|
||||||
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
@@ -1173,7 +1154,13 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
|
output_type=SequenceClassifierOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
modality="audio",
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
@@ -1188,29 +1175,6 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
|
|
||||||
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
|
|
||||||
>>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
|
|
||||||
|
|
||||||
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
|
|
||||||
|
|
||||||
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
|
|
||||||
>>> logits = model(input_values).logits
|
|
||||||
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
|
|
||||||
|
|
||||||
>>> # compute loss
|
|
||||||
>>> target_label = "down"
|
|
||||||
>>> labels = torch.tensor([model.config.label2id[target_label]])
|
|
||||||
|
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1225,7 +1189,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.config.use_weighted_layer_sum:
|
if self.config.use_weighted_layer_sum:
|
||||||
hidden_states = outputs[1]
|
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||||
hidden_states = torch.stack(hidden_states, dim=1)
|
hidden_states = torch.stack(hidden_states, dim=1)
|
||||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||||
@@ -1248,7 +1212,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return SequenceClassifierOutput(
|
return SequenceClassifierOutput(
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
|||||||
_import_structure["modeling_sew"] = [
|
_import_structure["modeling_sew"] = [
|
||||||
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"SEWForCTC",
|
"SEWForCTC",
|
||||||
|
"SEWForSequenceClassification",
|
||||||
"SEWModel",
|
"SEWModel",
|
||||||
"SEWPreTrainedModel",
|
"SEWPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -36,7 +37,13 @@ if TYPE_CHECKING:
|
|||||||
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
|
from .configuration_sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_sew import SEW_PRETRAINED_MODEL_ARCHIVE_LIST, SEWForCTC, SEWModel, SEWPreTrainedModel
|
from .modeling_sew import (
|
||||||
|
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
SEWForCTC,
|
||||||
|
SEWForSequenceClassification,
|
||||||
|
SEWModel,
|
||||||
|
SEWPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -113,6 +113,11 @@ class SEWConfig(PretrainedConfig):
|
|||||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||||
instance of :class:`~transformers.SEWForCTC`.
|
instance of :class:`~transformers.SEWForCTC`.
|
||||||
|
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
||||||
|
instance of :class:`~transformers.SEWForSequenceClassification`.
|
||||||
|
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||||
|
Dimensionality of the projection before token mean-pooling for classification.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -161,6 +166,8 @@ class SEWConfig(PretrainedConfig):
|
|||||||
mask_feature_length=10,
|
mask_feature_length=10,
|
||||||
ctc_loss_reduction="sum",
|
ctc_loss_reduction="sum",
|
||||||
ctc_zero_infinity=False,
|
ctc_zero_infinity=False,
|
||||||
|
use_weighted_layer_sum=False,
|
||||||
|
classifier_proj_size=256,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
@@ -214,3 +221,7 @@ class SEWConfig(PretrainedConfig):
|
|||||||
# ctc loss
|
# ctc loss
|
||||||
self.ctc_loss_reduction = ctc_loss_reduction
|
self.ctc_loss_reduction = ctc_loss_reduction
|
||||||
self.ctc_zero_infinity = ctc_zero_infinity
|
self.ctc_zero_infinity = ctc_zero_infinity
|
||||||
|
|
||||||
|
# sequence classification
|
||||||
|
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
|
self.classifier_proj_size = classifier_proj_size
|
||||||
|
|||||||
@@ -21,12 +21,18 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
from ...file_utils import (
|
||||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_sew import SEWConfig
|
from .configuration_sew import SEWConfig
|
||||||
@@ -36,6 +42,13 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
_CONFIG_FOR_DOC = "SEWConfig"
|
_CONFIG_FOR_DOC = "SEWConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
|
||||||
|
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
|
|
||||||
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"asapp/sew-tiny-100k",
|
"asapp/sew-tiny-100k",
|
||||||
@@ -900,6 +913,7 @@ class SEWModel(SEWPreTrainedModel):
|
|||||||
"""SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
"""SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||||
SEW_START_DOCSTRING,
|
SEW_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
|
||||||
class SEWForCTC(SEWPreTrainedModel):
|
class SEWForCTC(SEWPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -926,7 +940,12 @@ class SEWForCTC(SEWPreTrainedModel):
|
|||||||
self.sew.feature_extractor._freeze_parameters()
|
self.sew.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=CausalLMOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
@@ -942,41 +961,6 @@ class SEWForCTC(SEWPreTrainedModel):
|
|||||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||||
config.vocab_size - 1]``.
|
config.vocab_size - 1]``.
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import Wav2Vec2Processor, SEWForCTC
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k")
|
|
||||||
>>> model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k")
|
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
... speech, _ = sf.read(batch["file"])
|
|
||||||
... batch["speech"] = speech
|
|
||||||
... return batch
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
|
||||||
>>> logits = model(input_values).logits
|
|
||||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
|
||||||
|
|
||||||
>>> transcription = processor.decode(predicted_ids[0])
|
|
||||||
|
|
||||||
>>> # compute loss
|
|
||||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
|
||||||
|
|
||||||
>>> # wrap processor as target processor to encode labels
|
|
||||||
>>> with processor.as_target_processor():
|
|
||||||
... labels = processor(target_transcription, return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1027,9 +1011,115 @@ class SEWForCTC(SEWPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return CausalLMOutput(
|
return CausalLMOutput(
|
||||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
|
||||||
|
Keyword Spotting.
|
||||||
|
""",
|
||||||
|
SEW_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV_2_VEC_2->SEW
|
||||||
|
class SEWForSequenceClassification(SEWPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.sew = SEWModel(config)
|
||||||
|
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
||||||
|
if config.use_weighted_layer_sum:
|
||||||
|
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
||||||
|
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
||||||
|
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def freeze_feature_extractor(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||||
|
will not be updated during training.
|
||||||
|
"""
|
||||||
|
self.sew.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.sew.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
|
output_type=SequenceClassifierOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
modality="audio",
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_values,
|
||||||
|
attention_mask=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
|
Labels for computing the sequence classification loss. Indices should be in :obj:`[0, ...,
|
||||||
|
config.num_labels - 1]`. The loss is always computed with Cross-Entropy over :obj:`config.num_labels`
|
||||||
|
classes; no regression (Mean-Square) loss is implemented for this head.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
||||||
|
|
||||||
|
outputs = self.sew(
|
||||||
|
input_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.use_weighted_layer_sum:
|
||||||
|
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||||
|
hidden_states = torch.stack(hidden_states, dim=1)
|
||||||
|
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||||
|
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||||
|
else:
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
hidden_states = self.projector(hidden_states)
|
||||||
|
if attention_mask is None:
|
||||||
|
pooled_output = hidden_states.mean(dim=1)
|
||||||
|
else:
|
||||||
|
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
||||||
|
hidden_states[~padding_mask] = 0.0
|
||||||
|
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
||||||
|
|
||||||
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return SequenceClassifierOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
|||||||
_import_structure["modeling_sew_d"] = [
|
_import_structure["modeling_sew_d"] = [
|
||||||
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"SEWDForCTC",
|
"SEWDForCTC",
|
||||||
|
"SEWDForSequenceClassification",
|
||||||
"SEWDModel",
|
"SEWDModel",
|
||||||
"SEWDPreTrainedModel",
|
"SEWDPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -36,7 +37,13 @@ if TYPE_CHECKING:
|
|||||||
from .configuration_sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
|
from .configuration_sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_sew_d import SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST, SEWDForCTC, SEWDModel, SEWDPreTrainedModel
|
from .modeling_sew_d import (
|
||||||
|
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
SEWDForCTC,
|
||||||
|
SEWDForSequenceClassification,
|
||||||
|
SEWDModel,
|
||||||
|
SEWDPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -131,6 +131,11 @@ class SEWDConfig(PretrainedConfig):
|
|||||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||||
instance of :class:`~transformers.SEWDForCTC`.
|
instance of :class:`~transformers.SEWDForCTC`.
|
||||||
|
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
||||||
|
instance of :class:`~transformers.SEWDForSequenceClassification`.
|
||||||
|
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||||
|
Dimensionality of the projection before token mean-pooling for classification.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -186,6 +191,8 @@ class SEWDConfig(PretrainedConfig):
|
|||||||
mask_feature_length=10,
|
mask_feature_length=10,
|
||||||
ctc_loss_reduction="sum",
|
ctc_loss_reduction="sum",
|
||||||
ctc_zero_infinity=False,
|
ctc_zero_infinity=False,
|
||||||
|
use_weighted_layer_sum=False,
|
||||||
|
classifier_proj_size=256,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
@@ -246,3 +253,7 @@ class SEWDConfig(PretrainedConfig):
|
|||||||
# ctc loss
|
# ctc loss
|
||||||
self.ctc_loss_reduction = ctc_loss_reduction
|
self.ctc_loss_reduction = ctc_loss_reduction
|
||||||
self.ctc_zero_infinity = ctc_zero_infinity
|
self.ctc_zero_infinity = ctc_zero_infinity
|
||||||
|
|
||||||
|
# sequence classification
|
||||||
|
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
|
self.classifier_proj_size = classifier_proj_size
|
||||||
|
|||||||
@@ -22,13 +22,18 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import _softmax_backward_data, nn
|
from torch import _softmax_backward_data, nn
|
||||||
from torch.nn import LayerNorm
|
from torch.nn import CrossEntropyLoss, LayerNorm
|
||||||
|
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
from ...file_utils import (
|
||||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
|
add_code_sample_docstrings,
|
||||||
|
add_start_docstrings,
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_sew_d import SEWDConfig
|
from .configuration_sew_d import SEWDConfig
|
||||||
@@ -38,6 +43,12 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
_CONFIG_FOR_DOC = "SEWDConfig"
|
_CONFIG_FOR_DOC = "SEWDConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
||||||
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
|
||||||
|
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"asapp/sew-d-tiny-100k",
|
"asapp/sew-d-tiny-100k",
|
||||||
@@ -1405,6 +1416,7 @@ class SEWDModel(SEWDPreTrainedModel):
|
|||||||
"""SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
"""SEW-D Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
|
||||||
SEWD_START_DOCSTRING,
|
SEWD_START_DOCSTRING,
|
||||||
)
|
)
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
|
||||||
class SEWDForCTC(SEWDPreTrainedModel):
|
class SEWDForCTC(SEWDPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1431,7 +1443,12 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
|||||||
self.sew_d.feature_extractor._freeze_parameters()
|
self.sew_d.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=CausalLMOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
@@ -1447,41 +1464,6 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
|||||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||||
config.vocab_size - 1]``.
|
config.vocab_size - 1]``.
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
>>> import torch
|
|
||||||
>>> from transformers import Wav2Vec2Processor, SEWDForCTC
|
|
||||||
>>> from datasets import load_dataset
|
|
||||||
>>> import soundfile as sf
|
|
||||||
|
|
||||||
>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k")
|
|
||||||
>>> model = SEWDForCTC.from_pretrained("asapp/sew-d-tiny-100k")
|
|
||||||
|
|
||||||
>>> def map_to_array(batch):
|
|
||||||
... speech, _ = sf.read(batch["file"])
|
|
||||||
... batch["speech"] = speech
|
|
||||||
... return batch
|
|
||||||
|
|
||||||
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
|
|
||||||
>>> ds = ds.map(map_to_array)
|
|
||||||
|
|
||||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
|
||||||
>>> logits = model(input_values).logits
|
|
||||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
|
||||||
|
|
||||||
>>> transcription = processor.decode(predicted_ids[0])
|
|
||||||
|
|
||||||
>>> # compute loss
|
|
||||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
|
||||||
|
|
||||||
>>> # wrap processor as target processor to encode labels
|
|
||||||
>>> with processor.as_target_processor():
|
|
||||||
... labels = processor(target_transcription, return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
>>> loss = model(input_values, labels=labels).loss
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1532,9 +1514,115 @@ class SEWDForCTC(SEWDPreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[1:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return CausalLMOutput(
|
return CausalLMOutput(
|
||||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
|
||||||
|
"""
|
||||||
|
SEW-D Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
|
||||||
|
Keyword Spotting.
|
||||||
|
""",
|
||||||
|
SEWD_START_DOCSTRING,
|
||||||
|
)
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEWD, wav2vec2->sew_d, WAV_2_VEC_2->SEWD
|
||||||
|
class SEWDForSequenceClassification(SEWDPreTrainedModel):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.sew_d = SEWDModel(config)
|
||||||
|
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
||||||
|
if config.use_weighted_layer_sum:
|
||||||
|
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
||||||
|
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
||||||
|
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
||||||
|
|
||||||
|
self.init_weights()
|
||||||
|
|
||||||
|
def freeze_feature_extractor(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||||
|
will not be updated during training.
|
||||||
|
"""
|
||||||
|
self.sew_d.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
|
def freeze_base_model(self):
|
||||||
|
"""
|
||||||
|
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||||
|
be updated during training. Only the classification head will be updated.
|
||||||
|
"""
|
||||||
|
for param in self.sew_d.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
||||||
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||||
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
|
output_type=SequenceClassifierOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
modality="audio",
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_values,
|
||||||
|
attention_mask=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
labels=None,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||||
|
Labels for computing the sequence classification loss. Indices should be in :obj:`[0, ...,
|
||||||
|
config.num_labels - 1]`. A classification loss (Cross-Entropy) is always computed, including when
|
||||||
|
:obj:`config.num_labels == 1`; no regression (Mean-Square) loss is implemented.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
||||||
|
|
||||||
|
outputs = self.sew_d(
|
||||||
|
input_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.use_weighted_layer_sum:
|
||||||
|
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||||
|
hidden_states = torch.stack(hidden_states, dim=1)
|
||||||
|
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||||
|
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||||
|
else:
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
|
||||||
|
hidden_states = self.projector(hidden_states)
|
||||||
|
if attention_mask is None:
|
||||||
|
pooled_output = hidden_states.mean(dim=1)
|
||||||
|
else:
|
||||||
|
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
||||||
|
hidden_states[~padding_mask] = 0.0
|
||||||
|
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
||||||
|
|
||||||
|
logits = self.classifier(pooled_output)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
|
return SequenceClassifierOutput(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|||||||
@@ -46,6 +46,12 @@ _CONFIG_FOR_DOC = "Wav2Vec2Config"
|
|||||||
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
|
||||||
|
_SEQ_CLASS_CHECKPOINT = ("superb/wav2vec2-base-superb-ks",)
|
||||||
|
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
|
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"facebook/wav2vec2-base-960h",
|
"facebook/wav2vec2-base-960h",
|
||||||
"facebook/wav2vec2-large-960h",
|
"facebook/wav2vec2-large-960h",
|
||||||
@@ -1557,7 +1563,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return CausalLMOutput(
|
return CausalLMOutput(
|
||||||
@@ -1602,8 +1608,8 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class="Wav2Vec2FeatureExtractor",
|
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||||
checkpoint="superb/wav2vec2-base-superb-ks",
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
modality="audio",
|
modality="audio",
|
||||||
@@ -1636,7 +1642,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.config.use_weighted_layer_sum:
|
if self.config.use_weighted_layer_sum:
|
||||||
hidden_states = outputs[2]
|
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||||
hidden_states = torch.stack(hidden_states, dim=1)
|
hidden_states = torch.stack(hidden_states, dim=1)
|
||||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||||
@@ -1659,7 +1665,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
|||||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
output = (logits,) + outputs[2:]
|
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||||
return ((loss,) + output) if loss is not None else output
|
return ((loss,) + output) if loss is not None else output
|
||||||
|
|
||||||
return SequenceClassifierOutput(
|
return SequenceClassifierOutput(
|
||||||
|
|||||||
@@ -3289,6 +3289,15 @@ class SEWForCTC:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SEWForSequenceClassification:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class SEWModel:
|
class SEWModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -3315,6 +3324,15 @@ class SEWDForCTC:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class SEWDForSequenceClassification:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class SEWDModel:
|
class SEWDModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import SEWForCTC, SEWModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
from transformers import (
|
||||||
|
SEWForCTC,
|
||||||
|
SEWForSequenceClassification,
|
||||||
|
SEWModel,
|
||||||
|
Wav2Vec2FeatureExtractor,
|
||||||
|
Wav2Vec2Processor,
|
||||||
|
)
|
||||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -219,6 +225,54 @@ class SEWModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||||
|
model = SEWForSequenceClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
# make sure that dropout is disabled
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
input_values = input_values[:3]
|
||||||
|
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||||
|
|
||||||
|
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||||
|
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||||
|
|
||||||
|
# pad input
|
||||||
|
for i in range(len(input_lengths)):
|
||||||
|
input_values[i, input_lengths[i] :] = 0.0
|
||||||
|
attention_mask[i, input_lengths[i] :] = 0
|
||||||
|
|
||||||
|
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||||
|
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||||
|
|
||||||
|
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||||
|
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||||
|
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||||
|
|
||||||
|
def check_seq_classifier_training(self, config, input_values, *args):
|
||||||
|
config.ctc_zero_infinity = True
|
||||||
|
model = SEWForSequenceClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# freeze everything but the classification head
|
||||||
|
model.freeze_base_model()
|
||||||
|
|
||||||
|
input_values = input_values[:3]
|
||||||
|
|
||||||
|
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||||
|
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||||
|
|
||||||
|
# pad input
|
||||||
|
for i in range(len(input_lengths)):
|
||||||
|
input_values[i, input_lengths[i] :] = 0.0
|
||||||
|
|
||||||
|
loss = model(input_values, labels=labels).loss
|
||||||
|
self.parent.assertFalse(torch.isinf(loss).item())
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
model = SEWForCTC(config)
|
model = SEWForCTC(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -241,7 +295,7 @@ class SEWModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (SEWForCTC, SEWModel) if is_torch_available() else ()
|
all_model_classes = (SEWForCTC, SEWModel, SEWForSequenceClassification) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
@@ -328,6 +382,14 @@ class SEWModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
self.assertIsNotNone(hidden_states.grad)
|
self.assertIsNotNone(hidden_states.grad)
|
||||||
self.assertIsNotNone(attentions.grad)
|
self.assertIsNotNone(attentions.grad)
|
||||||
|
|
||||||
|
def test_seq_classifier_loss_inference(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_seq_classifier_train(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import SEWDForCTC, SEWDModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor
|
from transformers import (
|
||||||
|
SEWDForCTC,
|
||||||
|
SEWDForSequenceClassification,
|
||||||
|
SEWDModel,
|
||||||
|
Wav2Vec2FeatureExtractor,
|
||||||
|
Wav2Vec2Processor,
|
||||||
|
)
|
||||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -240,6 +246,54 @@ class SEWDModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||||
|
model = SEWDForSequenceClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
# make sure that dropout is disabled
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
input_values = input_values[:3]
|
||||||
|
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||||
|
|
||||||
|
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||||
|
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||||
|
|
||||||
|
# pad input
|
||||||
|
for i in range(len(input_lengths)):
|
||||||
|
input_values[i, input_lengths[i] :] = 0.0
|
||||||
|
attention_mask[i, input_lengths[i] :] = 0
|
||||||
|
|
||||||
|
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||||
|
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||||
|
|
||||||
|
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||||
|
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||||
|
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||||
|
|
||||||
|
def check_seq_classifier_training(self, config, input_values, *args):
|
||||||
|
config.ctc_zero_infinity = True
|
||||||
|
model = SEWDForSequenceClassification(config=config)
|
||||||
|
model.to(torch_device)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
# freeze everything but the classification head
|
||||||
|
model.freeze_base_model()
|
||||||
|
|
||||||
|
input_values = input_values[:3]
|
||||||
|
|
||||||
|
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||||
|
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||||
|
|
||||||
|
# pad input
|
||||||
|
for i in range(len(input_lengths)):
|
||||||
|
input_values[i, input_lengths[i] :] = 0.0
|
||||||
|
|
||||||
|
loss = model(input_values, labels=labels).loss
|
||||||
|
self.parent.assertFalse(torch.isinf(loss).item())
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
model = SEWDForCTC(config)
|
model = SEWDForCTC(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -262,7 +316,7 @@ class SEWDModelTester:
|
|||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class SEWDModelTest(ModelTesterMixin, unittest.TestCase):
|
class SEWDModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (SEWDForCTC, SEWDModel) if is_torch_available() else ()
|
all_model_classes = (SEWDForCTC, SEWDModel, SEWDForSequenceClassification) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_headmasking = False
|
test_headmasking = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|||||||
Reference in New Issue
Block a user