Add Wav2Vec2 & Hubert ForSequenceClassification (#13153)
* Add hubert classifier + tests * Add hubert classifier + tests * Dummies for all classification tests * Wav2Vec2 classifier + ER test * Fix hubert integration tests * Add hubert IC * Pass tests for all classification tasks on Hubert * Pass all tests + copies * Move models to the SUPERB org
This commit is contained in:
@@ -64,6 +64,14 @@ HubertForCTC
|
||||
.. autoclass:: transformers.HubertForCTC
|
||||
:members: forward
|
||||
|
||||
|
||||
HubertForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.HubertForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
TFHubertModel
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -96,6 +96,14 @@ Wav2Vec2ForCTC
|
||||
.. autoclass:: transformers.Wav2Vec2ForCTC
|
||||
:members: forward
|
||||
|
||||
|
||||
Wav2Vec2ForSequenceClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Wav2Vec2ForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
Wav2Vec2ForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -818,6 +818,7 @@ if is_torch_available():
|
||||
[
|
||||
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"HubertForCTC",
|
||||
"HubertForSequenceClassification",
|
||||
"HubertModel",
|
||||
"HubertPreTrainedModel",
|
||||
]
|
||||
@@ -1128,6 +1129,7 @@ if is_torch_available():
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"Wav2Vec2Model",
|
||||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
@@ -2424,6 +2426,7 @@ if TYPE_CHECKING:
|
||||
from .models.hubert import (
|
||||
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
HubertForCTC,
|
||||
HubertForSequenceClassification,
|
||||
HubertModel,
|
||||
HubertPreTrainedModel,
|
||||
)
|
||||
@@ -2681,6 +2684,7 @@ if TYPE_CHECKING:
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
||||
_import_structure["modeling_hubert"] = [
|
||||
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"HubertForCTC",
|
||||
"HubertForSequenceClassification",
|
||||
"HubertModel",
|
||||
"HubertPreTrainedModel",
|
||||
]
|
||||
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
|
||||
from .modeling_hubert import (
|
||||
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
HubertForCTC,
|
||||
HubertForSequenceClassification,
|
||||
HubertModel,
|
||||
HubertPreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -115,6 +115,11 @@ class HubertConfig(PretrainedConfig):
|
||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||
instance of :class:`~transformers.HubertForCTC`.
|
||||
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
||||
instance of :class:`~transformers.HubertForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
@@ -165,6 +170,8 @@ class HubertConfig(PretrainedConfig):
|
||||
mask_feature_length=10,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
use_weighted_layer_sum=False,
|
||||
classifier_proj_size=256,
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
@@ -197,6 +204,8 @@ class HubertConfig(PretrainedConfig):
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
if (
|
||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Hubert checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import HubertConfig, HubertForSequenceClassification, Wav2Vec2FeatureExtractor, logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SUPPORTED_MODELS = ["UtteranceLevel"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
|
||||
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
|
||||
|
||||
downstream_dict = checkpoint["Downstream"]
|
||||
|
||||
hf_congfig = HubertConfig.from_pretrained(config_path)
|
||||
hf_model = HubertForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
|
||||
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
base_model_name, return_attention_mask=True, do_normalize=False
|
||||
)
|
||||
|
||||
if hf_congfig.use_weighted_layer_sum:
|
||||
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
||||
|
||||
hf_model.projector.weight.data = downstream_dict["projector.weight"]
|
||||
hf_model.projector.bias.data = downstream_dict["projector.bias"]
|
||||
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
|
||||
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
|
||||
|
||||
hf_feature_extractor.save_pretrained(model_dump_path)
|
||||
hf_model.save_pretrained(model_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
|
||||
)
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
|
||||
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
|
||||
args = parser.parse_args()
|
||||
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
|
||||
@@ -20,12 +20,13 @@ import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_hubert import HubertConfig
|
||||
@@ -735,6 +736,18 @@ class HubertPreTrainedModel(PreTrainedModel):
|
||||
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
||||
)
|
||||
# these two operations makes sure that all values before the output lengths idxs are attended to
|
||||
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
return attention_mask
|
||||
|
||||
|
||||
HUBERT_START_DOCSTRING = r"""
|
||||
Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units
|
||||
@@ -904,19 +917,8 @@ class HubertModel(HubertPreTrainedModel):
|
||||
extract_features = extract_features.transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
|
||||
)
|
||||
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
attention_mask[
|
||||
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
|
||||
] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
# compute reduced attention_mask corresponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||
|
||||
hidden_states = self.feature_projection(extract_features)
|
||||
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
||||
@@ -1070,3 +1072,128 @@ class HubertForCTC(HubertPreTrainedModel):
|
||||
return CausalLMOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
|
||||
SUPERB Keyword Spotting.
|
||||
""",
|
||||
HUBERT_START_DOCSTRING,
|
||||
)
|
||||
class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.hubert = HubertModel(config)
|
||||
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
||||
if config.use_weighted_layer_sum:
|
||||
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
||||
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
||||
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.hubert.feature_extractor._freeze_parameters()
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
|
||||
def freeze_base_model(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||
be updated during training. Only the classification head will be updated.
|
||||
"""
|
||||
for param in self.hubert.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
|
||||
>>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
|
||||
|
||||
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
|
||||
|
||||
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_label = "down"
|
||||
>>> labels = torch.tensor([model.config.label2id[target_label]])
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
||||
|
||||
outputs = self.hubert(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = outputs[1]
|
||||
hidden_states = torch.stack(hidden_states, dim=1)
|
||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||
else:
|
||||
hidden_states = outputs[0]
|
||||
|
||||
hidden_states = self.projector(hidden_states)
|
||||
if attention_mask is None:
|
||||
pooled_output = hidden_states.mean(dim=1)
|
||||
else:
|
||||
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
||||
hidden_states[~padding_mask] = 0.0
|
||||
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
||||
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"Wav2Vec2Model",
|
||||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -133,6 +133,11 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||
instance of :class:`~transformers.Wav2Vec2ForCTC`.
|
||||
use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
|
||||
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
@@ -191,6 +196,8 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
diversity_loss_weight=0.1,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
use_weighted_layer_sum=False,
|
||||
classifier_proj_size=256,
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
@@ -223,6 +230,8 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
if (
|
||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Hubert checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SUPPORTED_MODELS = ["UtteranceLevel"]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
|
||||
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
|
||||
|
||||
downstream_dict = checkpoint["Downstream"]
|
||||
|
||||
hf_congfig = Wav2Vec2Config.from_pretrained(config_path)
|
||||
hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
|
||||
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
base_model_name, return_attention_mask=True, do_normalize=False
|
||||
)
|
||||
|
||||
if hf_congfig.use_weighted_layer_sum:
|
||||
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
||||
|
||||
hf_model.projector.weight.data = downstream_dict["projector.weight"]
|
||||
hf_model.projector.bias.data = downstream_dict["projector.bias"]
|
||||
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
|
||||
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
|
||||
|
||||
hf_feature_extractor.save_pretrained(model_dump_path)
|
||||
hf_model.save_pretrained(model_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
|
||||
)
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
|
||||
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
|
||||
args = parser.parse_args()
|
||||
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
|
||||
@@ -83,9 +83,6 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
||||
"""
|
||||
Every array in the list is normalized to have zero mean and unit variance
|
||||
"""
|
||||
if isinstance(input_values[0], np.ndarray):
|
||||
input_values = [x.astype(np.float32) for x in input_values]
|
||||
|
||||
normed_input_values = [
|
||||
(x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
|
||||
]
|
||||
@@ -205,6 +202,9 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
|
||||
padded_input_values = padded_inputs["input_values"]
|
||||
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
|
||||
|
||||
if isinstance(padded_inputs["input_values"][0], np.ndarray):
|
||||
padded_inputs["input_values"] = [x.astype(np.float32) for x in padded_inputs["input_values"]]
|
||||
|
||||
# zero-mean and unit-variance normalization
|
||||
if self.do_normalize:
|
||||
padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
|
||||
|
||||
@@ -22,6 +22,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...deepspeed import is_deepspeed_zero3_enabled
|
||||
@@ -31,7 +32,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||
@@ -1057,7 +1058,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
extract_features = extract_features.transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute reduced attention_mask correponding to feature vectors
|
||||
# compute reduced attention_mask corresponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||
|
||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||
@@ -1527,3 +1528,126 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
return CausalLMOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
|
||||
SUPERB Keyword Spotting.
|
||||
""",
|
||||
WAV_2_VEC_2_START_DOCSTRING,
|
||||
)
|
||||
class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
||||
if config.use_weighted_layer_sum:
|
||||
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
||||
self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
|
||||
self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||
|
||||
def freeze_base_model(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||
be updated during training. Only the classification head will be updated.
|
||||
"""
|
||||
for param in self.wav2vec2.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
|
||||
Returns:
|
||||
|
||||
Example::
|
||||
|
||||
>>> import torch
|
||||
>>> from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
|
||||
>>> from datasets import load_dataset
|
||||
|
||||
>>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
|
||||
>>> model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks")
|
||||
|
||||
>>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
|
||||
|
||||
>>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
>>> predicted_class_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_label = "down"
|
||||
>>> labels = torch.tensor([model.config.label2id[target_label]])
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
# End copy
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = outputs[2]
|
||||
hidden_states = torch.stack(hidden_states, dim=1)
|
||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||
else:
|
||||
hidden_states = outputs[0]
|
||||
|
||||
hidden_states = self.projector(hidden_states)
|
||||
if attention_mask is None:
|
||||
pooled_output = hidden_states.mean(dim=1)
|
||||
else:
|
||||
padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
|
||||
hidden_states[~padding_mask] = 0.0
|
||||
pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
|
||||
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[2:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return SequenceClassifierOutput(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
@@ -1863,6 +1863,15 @@ class HubertForCTC:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class HubertForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HubertModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@@ -3473,6 +3482,15 @@ class Wav2Vec2ForPreTraining:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2ForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@@ -31,7 +31,13 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import HubertForCTC, HubertModel, Wav2Vec2Processor
|
||||
from transformers import (
|
||||
HubertForCTC,
|
||||
HubertForSequenceClassification,
|
||||
HubertModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
from transformers.models.hubert.modeling_hubert import _compute_mask_indices
|
||||
|
||||
|
||||
@@ -187,7 +193,32 @@ class HubertModelTester:
|
||||
self.parent.assertTrue(isinstance(sum_loss, float))
|
||||
self.parent.assertTrue(isinstance(mean_loss, float))
|
||||
|
||||
def check_training(self, config, input_values, *args):
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||
model = HubertForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
# make sure that dropout is disabled
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
|
||||
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||
|
||||
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||
|
||||
def check_ctc_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = HubertForCTC(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -216,6 +247,29 @@ class HubertModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = HubertForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# freeze everything but the classification head
|
||||
model.freeze_base_model()
|
||||
|
||||
input_values = input_values[:3]
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
|
||||
loss = model(input_values, labels=labels).loss
|
||||
self.parent.assertFalse(torch.isinf(loss).item())
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = HubertForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -238,7 +292,7 @@ class HubertModelTester:
|
||||
|
||||
@require_torch
|
||||
class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else ()
|
||||
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -258,9 +312,17 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_ctc_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_training(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -371,7 +433,7 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (HubertForCTC, HubertModel) if is_torch_available() else ()
|
||||
all_model_classes = (HubertForCTC, HubertForSequenceClassification, HubertModel) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -397,9 +459,17 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_ctc_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_training(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -557,6 +627,13 @@ class HubertModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
return ds["speech"][:num_samples]
|
||||
|
||||
def _load_superb(self, task, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("anton-l/superb_dummy", task, split="test")
|
||||
|
||||
return ds[:num_samples]
|
||||
|
||||
def test_inference_ctc_batched(self):
|
||||
model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device)
|
||||
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)
|
||||
@@ -579,3 +656,95 @@ class HubertModelIntegrationTest(unittest.TestCase):
|
||||
"sweat covered brion's body trickling into the tight loin cloth that was the only garment he wore",
|
||||
]
|
||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_inference_keyword_spotting(self):
|
||||
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
|
||||
input_data = self._load_superb("ks", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
|
||||
|
||||
expected_labels = [2, 6, 10, 9]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_intent_classification(self):
|
||||
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ic").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ic")
|
||||
input_data = self._load_superb("ic", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
|
||||
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
|
||||
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
|
||||
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
|
||||
|
||||
expected_labels_action = [1, 0, 4, 3]
|
||||
expected_logits_action = torch.tensor([5.9052, 12.5865, 4.4840, 10.0240], device=torch_device)
|
||||
expected_labels_object = [1, 10, 3, 4]
|
||||
expected_logits_object = torch.tensor([5.5316, 11.7946, 8.1672, 23.2415], device=torch_device)
|
||||
expected_labels_location = [0, 0, 0, 1]
|
||||
expected_logits_location = torch.tensor([5.2053, 8.9577, 10.0447, 8.1481], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
|
||||
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
|
||||
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
|
||||
|
||||
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
|
||||
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=3e-1))
|
||||
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=3e-1))
|
||||
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=3e-1))
|
||||
|
||||
def test_inference_speaker_identification(self):
|
||||
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-sid")
|
||||
input_data = self._load_superb("si", 4)
|
||||
|
||||
output_logits = []
|
||||
with torch.no_grad():
|
||||
for example in input_data["speech"]:
|
||||
input = processor(example, return_tensors="pt", padding=True)
|
||||
output = model(input.input_values.to(torch_device), attention_mask=None)
|
||||
output_logits.append(output.logits[0])
|
||||
output_logits = torch.stack(output_logits)
|
||||
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
|
||||
|
||||
expected_labels = [5, 1, 1, 3]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([78231.5547, 123166.6094, 122785.4141, 84851.2969], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=10))
|
||||
|
||||
def test_inference_emotion_recognition(self):
|
||||
model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
|
||||
input_data = self._load_superb("er", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
|
||||
|
||||
expected_labels = [1, 1, 2, 2]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([2.8384, 2.3389, 3.8564, 4.5558], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
# TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-1))
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
""" Testing suite for the PyTorch Wav2Vec2 model. """
|
||||
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
@@ -36,6 +35,7 @@ if is_torch_available():
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
@@ -194,7 +194,32 @@ class Wav2Vec2ModelTester:
|
||||
self.parent.assertTrue(isinstance(sum_loss, float))
|
||||
self.parent.assertTrue(isinstance(mean_loss, float))
|
||||
|
||||
def check_training(self, config, input_values, *args):
|
||||
def check_seq_classifier_loss(self, config, input_values, *args):
|
||||
model = Wav2Vec2ForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
# make sure that dropout is disabled
|
||||
model.eval()
|
||||
|
||||
input_values = input_values[:3]
|
||||
attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.long)
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
attention_mask[i, input_lengths[i] :] = 0
|
||||
|
||||
masked_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss.item()
|
||||
unmasked_loss = model(input_values, labels=labels).loss.item()
|
||||
|
||||
self.parent.assertTrue(isinstance(masked_loss, float))
|
||||
self.parent.assertTrue(isinstance(unmasked_loss, float))
|
||||
self.parent.assertTrue(masked_loss != unmasked_loss)
|
||||
|
||||
def check_ctc_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = Wav2Vec2ForCTC(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -223,6 +248,29 @@ class Wav2Vec2ModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_seq_classifier_training(self, config, input_values, *args):
|
||||
config.ctc_zero_infinity = True
|
||||
model = Wav2Vec2ForSequenceClassification(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# freeze everything but the classification head
|
||||
model.freeze_base_model()
|
||||
|
||||
input_values = input_values[:3]
|
||||
|
||||
input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]]
|
||||
labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))
|
||||
|
||||
# pad input
|
||||
for i in range(len(input_lengths)):
|
||||
input_values[i, input_lengths[i] :] = 0.0
|
||||
|
||||
loss = model(input_values, labels=labels).loss
|
||||
self.parent.assertFalse(torch.isinf(loss).item())
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = Wav2Vec2ForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -246,7 +294,9 @@ class Wav2Vec2ModelTester:
|
||||
@require_torch
|
||||
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -267,9 +317,17 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_ctc_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_training(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -384,7 +442,9 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining) if is_torch_available() else ()
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
@@ -411,9 +471,17 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
def test_seq_classifier_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
self.model_tester.check_seq_classifier_loss(*config_and_inputs)
|
||||
|
||||
def test_ctc_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_training(*config_and_inputs)
|
||||
|
||||
def test_seq_classifier_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -691,6 +759,13 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
return ds["speech"][:num_samples]
|
||||
|
||||
def _load_superb(self, task, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("anton-l/superb_dummy", task, split="test")
|
||||
|
||||
return ds[:num_samples]
|
||||
|
||||
def test_inference_ctc_normal(self):
|
||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||
model.to(torch_device)
|
||||
@@ -795,7 +870,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
# fmt: off
|
||||
expected_cosine_sim_masked = torch.tensor(
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831, 0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651, 0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371, 0.6997],
|
||||
[0.7458, 0.7188, 0.6418, 0.3729, 0.3741, 0.3694, 0.3110, 0.2257, 0.4403, 0.5415, 0.3950, 0.3701, 0.8831,
|
||||
0.8613, 0.5229, 0.6696, 0.7206, 0.7877, 0.6758, 0.8746, 0.6596, 0.6282, 0.6178, 0.5839, 0.5926, 0.6651,
|
||||
0.4635, 0.6332, 0.6572, 0.8776, 0.4999, 0.7001, 0.7257, 0.5098, 0.6229, 0.4566, 0.5261, 0.6363, 0.5371,
|
||||
0.6997],
|
||||
device=torch_device,
|
||||
)
|
||||
# fmt: on
|
||||
@@ -913,3 +991,92 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
expected_loss = 62.5170
|
||||
|
||||
self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
|
||||
|
||||
def test_inference_keyword_spotting(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
|
||||
input_data = self._load_superb("ks", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
|
||||
|
||||
expected_labels = [7, 6, 10, 9]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([6.1186, 11.8961, 10.2931, 6.0898], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_intent_classification(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
|
||||
input_data = self._load_superb("ic", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
|
||||
predicted_logits_action, predicted_ids_action = torch.max(outputs.logits[:, :6], dim=-1)
|
||||
predicted_logits_object, predicted_ids_object = torch.max(outputs.logits[:, 6:20], dim=-1)
|
||||
predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
|
||||
|
||||
expected_labels_action = [0, 0, 2, 3]
|
||||
expected_logits_action = torch.tensor([0.4568, 11.0848, 1.6621, 9.3841], device=torch_device)
|
||||
expected_labels_object = [3, 10, 3, 4]
|
||||
expected_logits_object = torch.tensor([1.5322, 10.7094, 5.2469, 22.1318], device=torch_device)
|
||||
expected_labels_location = [0, 0, 0, 1]
|
||||
expected_logits_location = torch.tensor([1.5335, 6.5096, 10.5704, 11.0569], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
|
||||
self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
|
||||
self.assertListEqual(predicted_ids_location.tolist(), expected_labels_location)
|
||||
|
||||
self.assertTrue(torch.allclose(predicted_logits_action, expected_logits_action, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(predicted_logits_object, expected_logits_object, atol=1e-2))
|
||||
self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=1e-2))
|
||||
|
||||
def test_inference_speaker_identification(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-sid").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-sid")
|
||||
input_data = self._load_superb("si", 4)
|
||||
|
||||
output_logits = []
|
||||
with torch.no_grad():
|
||||
for example in input_data["speech"]:
|
||||
input = processor(example, return_tensors="pt", padding=True)
|
||||
output = model(input.input_values.to(torch_device), attention_mask=None)
|
||||
output_logits.append(output.logits[0])
|
||||
output_logits = torch.stack(output_logits)
|
||||
predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
|
||||
|
||||
expected_labels = [251, 1, 1, 3]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([37.5627, 71.6362, 64.2419, 31.7778], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
def test_inference_emotion_recognition(self):
|
||||
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-er").to(torch_device)
|
||||
processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")
|
||||
input_data = self._load_superb("er", 4)
|
||||
inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
|
||||
|
||||
input_values = inputs.input_values.to(torch_device)
|
||||
attention_mask = inputs.attention_mask.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(input_values, attention_mask=attention_mask)
|
||||
predicted_logits, predicted_ids = torch.max(outputs.logits, dim=-1)
|
||||
|
||||
expected_labels = [1, 1, 2, 2]
|
||||
# s3prl logits for the same batch
|
||||
expected_logits = torch.tensor([2.1722, 3.0779, 8.0287, 6.6797], device=torch_device)
|
||||
|
||||
self.assertListEqual(predicted_ids.tolist(), expected_labels)
|
||||
self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
|
||||
|
||||
@@ -122,6 +122,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||
"TFRagTokenForGeneration",
|
||||
"Wav2Vec2ForCTC",
|
||||
"HubertForCTC",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"HubertForSequenceClassification",
|
||||
"XLMForQuestionAnswering",
|
||||
"XLNetForQuestionAnswering",
|
||||
"SeparableConv1D",
|
||||
|
||||
Reference in New Issue
Block a user