From 3883e3a75e20cd7bbda200993b218d81b97859e7 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Mon, 20 Dec 2021 16:40:56 +0300 Subject: [PATCH] Add SD and SV heads for WavLM (#14847) * Add converted heads * Add dummies --- docs/source/model_doc/wavlm.rst | 14 + src/transformers/__init__.py | 4 + src/transformers/models/auto/modeling_auto.py | 2 + src/transformers/models/wavlm/__init__.py | 4 + .../models/wavlm/configuration_wavlm.py | 24 ++ ...lm_original_s3prl_checkpoint_to_pytorch.py | 110 ++++++ .../models/wavlm/modeling_wavlm.py | 320 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 10 + tests/test_modeling_unispeech_sat.py | 8 +- tests/test_modeling_wavlm.py | 84 ++++- 10 files changed, 573 insertions(+), 7 deletions(-) create mode 100644 src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py diff --git a/docs/source/model_doc/wavlm.rst b/docs/source/model_doc/wavlm.rst index e371e3977f..59d030e83b 100644 --- a/docs/source/model_doc/wavlm.rst +++ b/docs/source/model_doc/wavlm.rst @@ -81,3 +81,17 @@ WavLMForSequenceClassification .. autoclass:: transformers.WavLMForSequenceClassification :members: forward + + +WavLMForAudioFrameClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.WavLMForAudioFrameClassification + :members: forward + + +WavLMForXVector +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.WavLMForXVector + :members: forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 0fe7eea849..5d59e4dd28 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1381,8 +1381,10 @@ if is_torch_available(): _import_structure["models.wavlm"].extend( [ "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "WavLMForAudioFrameClassification", "WavLMForCTC", "WavLMForSequenceClassification", + "WavLMForXVector", "WavLMModel", "WavLMPreTrainedModel", ] @@ -3230,8 +3232,10 @@ if TYPE_CHECKING: ) from .models.wavlm import ( WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST, + WavLMForAudioFrameClassification, WavLMForCTC, WavLMForSequenceClassification, + WavLMForXVector, WavLMModel, WavLMPreTrainedModel, ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 978551fd07..003cc088b6 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -546,6 +546,7 @@ MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( # Model for Audio Classification mapping ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"), ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"), + ("wavlm", "WavLMForAudioFrameClassification"), ] ) @@ -554,6 +555,7 @@ MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( # Model for Audio Classification mapping ("wav2vec2", "Wav2Vec2ForXVector"), ("unispeech-sat", "UniSpeechSatForXVector"), + ("wavlm", "WavLMForXVector"), ] ) diff --git a/src/transformers/models/wavlm/__init__.py b/src/transformers/models/wavlm/__init__.py index 7f8cc1041d..2cfc854919 100644 --- a/src/transformers/models/wavlm/__init__.py +++ b/src/transformers/models/wavlm/__init__.py @@ -27,8 +27,10 @@ _import_structure = { if is_torch_available(): _import_structure["modeling_wavlm"] = [ "WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "WavLMForAudioFrameClassification", "WavLMForCTC", "WavLMForSequenceClassification", 
+ "WavLMForXVector", "WavLMModel", "WavLMPreTrainedModel", ] @@ -39,8 +41,10 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_wavlm import ( WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST, + WavLMForAudioFrameClassification, WavLMForCTC, WavLMForSequenceClassification, + WavLMForXVector, WavLMModel, WavLMPreTrainedModel, ) diff --git a/src/transformers/models/wavlm/configuration_wavlm.py b/src/transformers/models/wavlm/configuration_wavlm.py index d8b99e315f..42de770543 100644 --- a/src/transformers/models/wavlm/configuration_wavlm.py +++ b/src/transformers/models/wavlm/configuration_wavlm.py @@ -144,6 +144,17 @@ class WavLMConfig(PretrainedConfig): instance of :class:`~transformers.WavLMForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. + tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`): + A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN` + module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers. + tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the + `XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`. + tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`): + A tuple of integers defining the dilation factor of each 1D convolutional layer in `TDNN` module of the + `XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`. + xvector_output_dim (:obj:`int`, `optional`, defaults to 512): + Dimensionality of the `XVector` embedding vectors. add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. 
Can be very useful for warm-starting Wav2Vec2 for SpeechEncoderDecoder models. @@ -220,6 +231,10 @@ class WavLMConfig(PretrainedConfig): ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, + tdnn_dim=(512, 512, 512, 512, 1500), + tdnn_kernel=(5, 3, 3, 1, 1), + tdnn_dilation=(1, 2, 3, 1, 1), + xvector_output_dim=512, num_ctc_classes=80, pad_token_id=0, bos_token_id=1, @@ -302,3 +317,12 @@ class WavLMConfig(PretrainedConfig): self.adapter_stride = adapter_stride self.num_adapter_layers = num_adapter_layers self.output_hidden_size = output_hidden_size or hidden_size + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. + self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim diff --git a/src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py new file mode 100644 index 0000000000..e41aa0099a --- /dev/null +++ b/src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert WavLM checkpoint.""" + + +import argparse + +import torch + +from transformers import ( + Wav2Vec2FeatureExtractor, + WavLMConfig, + WavLMForAudioFrameClassification, + WavLMForSequenceClassification, + WavLMForXVector, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def convert_classification(base_model_name, hf_config, downstream_dict): + model = WavLMForSequenceClassification.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["projector.weight"] + model.projector.bias.data = downstream_dict["projector.bias"] + model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + return model + + +def convert_diarization(base_model_name, hf_config, downstream_dict): + model = WavLMForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config) + model.classifier.weight.data = downstream_dict["model.linear.weight"] + model.classifier.bias.data = downstream_dict["model.linear.bias"] + return model + + +def convert_xvector(base_model_name, hf_config, downstream_dict): + model = WavLMForXVector.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["connector.weight"] + model.projector.bias.data = downstream_dict["connector.bias"] + for i, kernel_size in enumerate(hf_config.tdnn_kernel): + model.tdnn[i].kernel.weight.data = downstream_dict[ + f"model.framelevel_feature_extractor.module.{i}.kernel.weight" + ] + model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"] + + model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"] + model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"] + model.classifier.weight.data = 
downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"] + model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"] + model.objective.weight.data = downstream_dict["objective.W"] + return model + + +@torch.no_grad() +def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path): + """ + Copy/paste/tweak model's weights to transformers design. + """ + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + downstream_dict = checkpoint["Downstream"] + + hf_config = WavLMConfig.from_pretrained(config_path) + hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + base_model_name, return_attention_mask=True, do_normalize=False + ) + + arch = hf_config.architectures[0] + if arch.endswith("ForSequenceClassification"): + hf_model = convert_classification(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForAudioFrameClassification"): + hf_model = convert_diarization(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForXVector"): + hf_model = convert_xvector(base_model_name, hf_config, downstream_dict) + else: + raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}") + + if hf_config.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + + hf_feature_extractor.save_pretrained(model_dump_path) + hf_model.save_pretrained(model_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model." 
+ ) + parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.") + parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.") + args = parser.parse_args() + convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index b55de4801f..d16472beb6 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -33,7 +33,7 @@ from ...file_utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, ) -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_wavlm import WavLMConfig @@ -48,6 +48,10 @@ _CHECKPOINT_FOR_DOC = "microsoft/wavlm-base" _SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base" _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" +_SEQ_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus" +_FRAME_CLASS_CHECKPOINT = "microsoft/wavlm-base-plus-sd" +_XVECTOR_CHECKPOINT = "microsoft/wavlm-base-plus-sv" + _HIDDEN_STATES_START_POSITION = 2 WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = [ @@ -87,6 +91,38 @@ class WavLMBaseModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of :class:`~transformers.WavLMForXVector`. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. 
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + embeddings: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], @@ -1447,3 +1483,285 @@ class WavLMForSequenceClassification(WavLMPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + WavLM Model with a frame classification head on top for tasks like Speaker Diarization. 
+ """, + WAVLM_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM +class WavLMForAudioFrameClassification(WavLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wavlm = WavLMModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wavlm.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wavlm.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_FRAME_CLASS_CHECKPOINT, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wavlm( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=None, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + super(AMSoftmaxLoss, self).__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer +class TDNNLayer(nn.Module): + def __init__(self, config, 
layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = config.tdnn_dim[layer_id] + self.kernel_size = config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + hidden_states = hidden_states.unsqueeze(1) + hidden_states = nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + WAVLM_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->WavLM, wav2vec2->wavlm, WAV_2_VEC_2->WAVLM +class WavLMForXVector(WavLMPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wavlm = WavLMModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the 
gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wavlm.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wavlm.parameters(): + param.requires_grad = False + + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_XVECTOR_CHECKPOINT, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wavlm( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 94d3199609..02afebeca0 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5177,6 +5177,11 @@ class Wav2Vec2PreTrainedModel: WAVLM_PRETRAINED_MODEL_ARCHIVE_LIST = None +class WavLMForAudioFrameClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class WavLMForCTC: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -5194,6 +5199,11 @@ class WavLMForSequenceClassification: requires_backends(self, ["torch"]) +class WavLMForXVector: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class WavLMModel: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_unispeech_sat.py b/tests/test_modeling_unispeech_sat.py index a48a28b1a0..6e50f1f003 100644 --- a/tests/test_modeling_unispeech_sat.py +++ b/tests/test_modeling_unispeech_sat.py @@ -863,10 +863,10 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase): self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3)) def test_inference_diarization(self): - model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to( + model = UniSpeechSatForAudioFrameClassification.from_pretrained("microsoft/unispeech-sat-base-plus-sd").to( torch_device ) - processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd") + processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sd") input_data = self._load_superb("sd", 4) inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000) @@ -892,8 +892,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase): 
self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3)) def test_inference_speaker_verification(self): - model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device) - processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv") + model = UniSpeechSatForXVector.from_pretrained("microsoft/unispeech-sat-base-plus-sv").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/unispeech-sat-base-plus-sv") input_data = self._load_superb("si", 4) inputs = processor(input_data["speech"], return_tensors="pt", padding=True) diff --git a/tests/test_modeling_wavlm.py b/tests/test_modeling_wavlm.py index f1b40308bd..d0be0d6886 100644 --- a/tests/test_modeling_wavlm.py +++ b/tests/test_modeling_wavlm.py @@ -31,7 +31,14 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init if is_torch_available(): import torch - from transformers import Wav2Vec2FeatureExtractor, WavLMForCTC, WavLMForSequenceClassification, WavLMModel + from transformers import ( + Wav2Vec2FeatureExtractor, + WavLMForAudioFrameClassification, + WavLMForCTC, + WavLMForSequenceClassification, + WavLMForXVector, + WavLMModel, + ) class WavLMModelTester: @@ -60,6 +67,10 @@ class WavLMModelTester: initializer_range=0.02, vocab_size=32, do_stable_layer_norm=False, + tdnn_dim=(32, 32), + tdnn_kernel=(3, 3), + tdnn_dilation=(1, 1), + xvector_output_dim=32, scope=None, ): self.parent = parent @@ -85,6 +96,10 @@ class WavLMModelTester: self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm + self.tdnn_dim = tdnn_dim + self.tdnn_kernel = tdnn_kernel + self.tdnn_dilation = tdnn_dilation + self.xvector_output_dim = xvector_output_dim self.scope = scope output_seq_length = self.seq_length @@ -121,6 +136,10 @@ class WavLMModelTester: hidden_act=self.hidden_act, initializer_range=self.initializer_range, 
vocab_size=self.vocab_size, + tdnn_dim=self.tdnn_dim, + tdnn_kernel=self.tdnn_kernel, + tdnn_dilation=self.tdnn_dilation, + xvector_output_dim=self.xvector_output_dim, ) def create_and_check_model(self, config, input_values, attention_mask): @@ -285,7 +304,11 @@ class WavLMModelTester: @require_torch class WavLMModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (WavLMForCTC, WavLMModel, WavLMForSequenceClassification) if is_torch_available() else () + all_model_classes = ( + (WavLMForCTC, WavLMModel, WavLMForAudioFrameClassification, WavLMForSequenceClassification, WavLMForXVector) + if is_torch_available() + else () + ) test_pruning = False test_headmasking = False test_torchscript = False @@ -398,6 +421,7 @@ class WavLMModelTest(ModelTesterMixin, unittest.TestCase): "feature_projection.projection.bias", "label_embeddings_concat", "rel_attn_embed", + "objective.weight", ] if param.requires_grad: if any([x in name for x in uniform_init_parms]): @@ -446,6 +470,11 @@ class WavLMModelIntegrationTest(unittest.TestCase): return [x["array"] for x in speech_samples] + def _load_superb(self, task, num_samples): + ds = load_dataset("anton-l/superb_dummy", task, split="test") + + return ds[:num_samples] + def test_inference_base(self): model = WavLMModel.from_pretrained("microsoft/wavlm-base-plus").to(torch_device) feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( @@ -491,3 +520,54 @@ class WavLMModelIntegrationTest(unittest.TestCase): [[[0.1612, 0.4314], [0.1690, 0.4344]], [[0.2086, 0.1396], [0.3014, 0.0903]]] ) self.assertTrue(torch.allclose(hidden_states_slice, EXPECTED_HIDDEN_STATES_SLICE, rtol=1e-2)) + + def test_inference_diarization(self): + model = WavLMForAudioFrameClassification.from_pretrained("microsoft/wavlm-base-plus-sd").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sd") + input_data = self._load_superb("sd", 4) + inputs = processor(input_data["speech"], return_tensors="pt", 
padding=True, sampling_rate=16_000) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + # labels is a one-hot array of shape (num_frames, num_speakers) + labels = (outputs.logits > 0).long() + + # s3prl logits for the same batch + expected_logits = torch.tensor( + [ + [[-5.9566, -8.6554], [-5.7137, -8.9386], [-5.7906, -7.0973], [-5.7829, -5.9999]], + [[-5.2086, -7.7878], [-4.8890, -7.9312], [-4.2004, -3.9101], [-5.4480, -4.6932]], + [[-4.6105, -6.7178], [-5.1930, -6.1635], [-2.6228, -4.1123], [-2.7646, -3.1576]], + [[-4.4477, -7.9206], [-3.9339, -7.3707], [-4.9528, -4.8242], [-3.6921, -2.9687]], + ], + device=torch_device, + ) + self.assertEqual(labels[0, :, 0].sum(), 258) + self.assertEqual(labels[0, :, 1].sum(), 647) + self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3)) + + def test_inference_speaker_verification(self): + model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("microsoft/wavlm-base-plus-sv") + input_data = self._load_superb("si", 4) + + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + labels = torch.tensor([5, 1, 1, 3], device=torch_device).T + + with torch.no_grad(): + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + outputs = model(input_values, attention_mask=attention_mask, labels=labels) + embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1) + + cosine_sim = torch.nn.CosineSimilarity(dim=-1) + # id10002 vs id10002 + self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9787, 3) + # id10006 vs id10002 + self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.5064, 3) + # id10002 vs id10004 + 
self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.4780, 3) + + self.assertAlmostEqual(outputs.loss.item(), 18.4154, 3)