From 48463ebb33c4a3f4035dbdaf55dc43778304f318 Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 16 Dec 2021 19:22:14 +0300 Subject: [PATCH] Add Speaker Diarization and Verification heads (#14723) * Models * Squashed commit of the following: commit 72278e1e931a16d0879acc77f65762f3364833d0 Author: anton-l Date: Fri Dec 10 21:45:08 2021 +0300 * Add unispeech heads * Add sd/sv automodels * Docs cleanup * Fix docstrings * rename xvector classes * examples * Tests cleanup * Style * Better checkpoints for tests * leftover docs * apply review suggestions * Style + init tests * Update unispeech-sat tdnn downsampling --- docs/source/model_doc/auto.rst | 14 + docs/source/model_doc/unispeech_sat.rst | 14 + docs/source/model_doc/wav2vec2.rst | 14 + src/transformers/__init__.py | 12 + src/transformers/file_utils.py | 54 +++ src/transformers/models/auto/__init__.py | 4 + src/transformers/models/auto/modeling_auto.py | 36 ++ .../models/hubert/modeling_hubert.py | 4 +- src/transformers/models/sew/modeling_sew.py | 4 +- .../models/sew_d/modeling_sew_d.py | 4 +- .../models/unispeech/modeling_unispeech.py | 4 +- .../models/unispeech_sat/__init__.py | 4 + .../configuration_unispeech_sat.py | 25 +- ...ch_original_s3prl_checkpoint_to_pytorch.py | 110 ++++++ .../unispeech_sat/modeling_unispeech_sat.py | 322 ++++++++++++++++- src/transformers/models/wav2vec2/__init__.py | 4 + .../models/wav2vec2/configuration_wav2vec2.py | 29 +- ...c2_original_s3prl_checkpoint_to_pytorch.py | 65 +++- .../models/wav2vec2/modeling_wav2vec2.py | 326 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 44 +++ tests/test_modeling_unispeech_sat.py | 106 +++++- tests/test_modeling_wav2vec2.py | 108 +++++- utils/update_metadata.py | 6 + 23 files changed, 1280 insertions(+), 33 deletions(-) create mode 100644 src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst index 
ad69cf4dc0..da6aa1fcfc 100644 --- a/docs/source/model_doc/auto.rst +++ b/docs/source/model_doc/auto.rst @@ -181,6 +181,13 @@ AutoModelForAudioClassification :members: +AutoModelForAudioFrameClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.AutoModelForAudioFrameClassification + :members: + + AutoModelForCTC ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -195,6 +202,13 @@ AutoModelForSpeechSeq2Seq :members: +AutoModelForAudioXVector +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.AutoModelForAudioXVector + :members: + + AutoModelForObjectDetection ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/unispeech_sat.rst b/docs/source/model_doc/unispeech_sat.rst index 5418642371..96c4311b06 100644 --- a/docs/source/model_doc/unispeech_sat.rst +++ b/docs/source/model_doc/unispeech_sat.rst @@ -85,6 +85,20 @@ UniSpeechSatForSequenceClassification :members: forward +UniSpeechSatForAudioFrameClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.UniSpeechSatForAudioFrameClassification + :members: forward + + +UniSpeechSatForXVector +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.UniSpeechSatForXVector + :members: forward + + UniSpeechSatForPreTraining ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index 3ac721b2f9..b988f11408 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -114,6 +114,20 @@ Wav2Vec2ForSequenceClassification :members: forward +Wav2Vec2ForAudioFrameClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Wav2Vec2ForAudioFrameClassification + :members: forward + + +Wav2Vec2ForXVector +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.Wav2Vec2ForXVector + :members: forward + + Wav2Vec2ForPreTraining ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fc9f8a045b..2adf6c6a59 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -649,6 +649,8 @@ if is_torch_available(): "MODEL_WITH_LM_HEAD_MAPPING", "AutoModel", "AutoModelForAudioClassification", + "AutoModelForAudioFrameClassification", + "AutoModelForAudioXVector", "AutoModelForCausalLM", "AutoModelForCTC", "AutoModelForImageClassification", @@ -1325,9 +1327,11 @@ if is_torch_available(): _import_structure["models.unispeech_sat"].extend( [ "UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST", + "UniSpeechSatForAudioFrameClassification", "UniSpeechSatForCTC", "UniSpeechSatForPreTraining", "UniSpeechSatForSequenceClassification", + "UniSpeechSatForXVector", "UniSpeechSatModel", "UniSpeechSatPreTrainedModel", ] @@ -1358,10 +1362,12 @@ if is_torch_available(): 
_import_structure["models.wav2vec2"].extend( [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Wav2Vec2ForAudioFrameClassification", "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", "Wav2Vec2ForPreTraining", "Wav2Vec2ForSequenceClassification", + "Wav2Vec2ForXVector", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -2603,6 +2609,8 @@ if TYPE_CHECKING: MODEL_WITH_LM_HEAD_MAPPING, AutoModel, AutoModelForAudioClassification, + AutoModelForAudioFrameClassification, + AutoModelForAudioXVector, AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, @@ -3164,9 +3172,11 @@ if TYPE_CHECKING: ) from .models.unispeech_sat import ( UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST, + UniSpeechSatForAudioFrameClassification, UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatForSequenceClassification, + UniSpeechSatForXVector, UniSpeechSatModel, UniSpeechSatPreTrainedModel, ) @@ -3191,10 +3201,12 @@ if TYPE_CHECKING: ) from .models.wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, + Wav2Vec2ForAudioFrameClassification, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index adbacd8aed..b8dcd70f1e 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -1117,6 +1117,54 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r""" """ +PT_SPEECH_FRAME_CLASS_SAMPLE = r""" + Example:: + + >>> from transformers import {processor_class}, {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> # audio file is decoded on the fly + >>> inputs = 
feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt") + >>> logits = model(**inputs).logits + >>> probabilities = torch.sigmoid(logits[0]) + >>> # labels is a one-hot array of shape (num_frames, num_speakers) + >>> labels = (probabilities > 0.5).long() +""" + + +PT_SPEECH_XVECTOR_SAMPLE = r""" + Example:: + + >>> from transformers import {processor_class}, {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}') + >>> model = {model_class}.from_pretrained('{checkpoint}') + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt") + >>> embeddings = model(**inputs).embeddings + >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu() + + >>> # the resulting embeddings can be used for cosine similarity-based retrieval + >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1) + >>> similarity = cosine_sim(embeddings[0], embeddings[1]) + >>> threshold = 0.7 # the optimal threshold is dataset-dependent + >>> if similarity < threshold: + ... 
print("Speakers are not the same!") +""" + PT_SAMPLE_DOCSTRINGS = { "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE, "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE, @@ -1128,6 +1176,8 @@ PT_SAMPLE_DOCSTRINGS = { "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE, "CTC": PT_SPEECH_CTC_SAMPLE, "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE, + "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE, + "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE, } @@ -1419,6 +1469,10 @@ def add_code_sample_docstrings( code_sample = sample_docstrings["LMHead"] elif "CTC" in model_class: code_sample = sample_docstrings["CTC"] + elif "AudioFrameClassification" in model_class: + code_sample = sample_docstrings["AudioFrameClassification"] + elif "XVector" in model_class and modality == "audio": + code_sample = sample_docstrings["AudioXVector"] elif "Model" in model_class and modality == "audio": code_sample = sample_docstrings["SpeechBaseModel"] elif "Model" in model_class or "Encoder" in model_class: diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index bd6e3a369b..60b2d725b0 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -53,6 +53,8 @@ if is_torch_available(): "MODEL_WITH_LM_HEAD_MAPPING", "AutoModel", "AutoModelForAudioClassification", + "AutoModelForAudioFrameClassification", + "AutoModelForAudioXVector", "AutoModelForCausalLM", "AutoModelForCTC", "AutoModelForImageClassification", @@ -161,6 +163,8 @@ if TYPE_CHECKING: MODEL_WITH_LM_HEAD_MAPPING, AutoModel, AutoModelForAudioClassification, + AutoModelForAudioFrameClassification, + AutoModelForAudioXVector, AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 7fca274e3e..cbc340c11d 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py 
@@ -538,6 +538,22 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict( + [ + # Model for Audio Classification mapping + ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"), + ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"), + ] +) + +MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict( + [ + # Model for Audio Classification mapping + ("wav2vec2", "Wav2Vec2ForXVector"), + ("unispeech-sat", "UniSpeechSatForXVector"), + ] +) + MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES) MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES) MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES) @@ -578,6 +594,10 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping( ) MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES) MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES) +MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES +) +MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES) class AutoModel(_BaseAutoModelClass): @@ -726,6 +746,22 @@ AutoModelForSpeechSeq2Seq = auto_class_update( ) +class AutoModelForAudioFrameClassification(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING + + +AutoModelForAudioFrameClassification = auto_class_update( + AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification" +) + + +class AutoModelForAudioXVector(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING + + +AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") + + class 
AutoModelWithLMHead(_AutoModelWithLMHead): @classmethod def from_config(cls, config): diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 6d2affd2df..a26066f2f5 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -42,9 +42,9 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "HubertConfig" _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks" -_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _HIDDEN_STATES_START_POSITION = 1 @@ -1182,7 +1182,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel): @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING) @add_code_sample_docstrings( - processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC, + processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index c028103c7f..5cecf18f2e 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -38,9 +38,9 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "SEWConfig" _CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k" -_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _HIDDEN_STATES_START_POSITION = 1 @@ -1067,7 +1067,7 @@ class SEWForSequenceClassification(SEWPreTrainedModel): @add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING) @add_code_sample_docstrings( - processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC, + processor_class=_FEAT_EXTRACTOR_FOR_DOC, 
checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 677c1384f7..e4e1bb0f0b 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -39,9 +39,9 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "SEWDConfig" _CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k" -_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _HIDDEN_STATES_START_POSITION = 1 @@ -1598,7 +1598,7 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel): @add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING) @add_code_sample_docstrings( - processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC, + processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index ef6ebb6714..cfacf721b0 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -44,9 +44,9 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "UniSpeechConfig" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv" -_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _HIDDEN_STATES_START_POSITION = 2 @@ -1481,7 +1481,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel): @add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING) @add_code_sample_docstrings( - 
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC, + processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, diff --git a/src/transformers/models/unispeech_sat/__init__.py b/src/transformers/models/unispeech_sat/__init__.py index 6176d8efa3..a6479e962d 100644 --- a/src/transformers/models/unispeech_sat/__init__.py +++ b/src/transformers/models/unispeech_sat/__init__.py @@ -27,9 +27,11 @@ _import_structure = { if is_torch_available(): _import_structure["modeling_unispeech_sat"] = [ "UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST", + "UniSpeechSatForAudioFrameClassification", "UniSpeechSatForCTC", "UniSpeechSatForPreTraining", "UniSpeechSatForSequenceClassification", + "UniSpeechSatForXVector", "UniSpeechSatModel", "UniSpeechSatPreTrainedModel", ] @@ -40,9 +42,11 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_unispeech_sat import ( UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST, + UniSpeechSatForAudioFrameClassification, UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatForSequenceClassification, + UniSpeechSatForXVector, UniSpeechSatModel, UniSpeechSatPreTrainedModel, ) diff --git a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py index ecf8b01f1c..d9c2a00fce 100644 --- a/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/configuration_unispeech_sat.py @@ -153,6 +153,17 @@ class UniSpeechSatConfig(PretrainedConfig): instance of :class:`~transformers.UniSpeechSatForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. 
+ tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`): + A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN` + module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers. + tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the + `XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`. + tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`): + A tuple of integers defining the dilation factor of each 1D convolutional layer in `TDNN` module of the + `XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`. + xvector_output_dim (:obj:`int`, `optional`, defaults to 512): + Dimensionality of the `XVector` embedding vectors. Example:: @@ -213,6 +224,10 @@ class UniSpeechSatConfig(PretrainedConfig): ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, + tdnn_dim=(512, 512, 512, 512, 1500), + tdnn_kernel=(5, 3, 3, 1, 1), + tdnn_dilation=(1, 2, 3, 1, 1), + xvector_output_dim=512, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -246,7 +261,6 @@ class UniSpeechSatConfig(PretrainedConfig): self.num_clusters = num_clusters self.do_stable_layer_norm = do_stable_layer_norm self.use_weighted_layer_sum = use_weighted_layer_sum - self.classifier_proj_size = classifier_proj_size if ( (len(self.conv_stride) != self.num_feat_extract_layers) @@ -282,3 +296,12 @@ class UniSpeechSatConfig(PretrainedConfig): # ctc loss self.ctc_loss_reduction = ctc_loss_reduction self.ctc_zero_infinity = ctc_zero_infinity + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. 
+ self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim diff --git a/src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py new file mode 100644 index 0000000000..56c9d52e18 --- /dev/null +++ b/src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert UniSpeechSat S3PRL checkpoint.""" + + +import argparse + +import torch + +from transformers import ( + UniSpeechSatConfig, + UniSpeechSatForAudioFrameClassification, + UniSpeechSatForSequenceClassification, + UniSpeechSatForXVector, + Wav2Vec2FeatureExtractor, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def convert_classification(base_model_name, hf_config, downstream_dict): + model = UniSpeechSatForSequenceClassification.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["projector.weight"] + model.projector.bias.data = downstream_dict["projector.bias"] + model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + return model + + +def convert_diarization(base_model_name, hf_config, downstream_dict): + model = UniSpeechSatForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config) + model.classifier.weight.data = downstream_dict["model.linear.weight"] + model.classifier.bias.data = downstream_dict["model.linear.bias"] + return model + + +def convert_xvector(base_model_name, hf_config, downstream_dict): + model = UniSpeechSatForXVector.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["connector.weight"] + model.projector.bias.data = downstream_dict["connector.bias"] + for i, kernel_size in enumerate(hf_config.tdnn_kernel): + model.tdnn[i].kernel.weight.data = downstream_dict[ + f"model.framelevel_feature_extractor.module.{i}.kernel.weight" + ] + model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"] + + model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"] + model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"] + model.classifier.weight.data = 
downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"] + model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"] + model.objective.weight.data = downstream_dict["objective.W"] + return model + + +@torch.no_grad() +def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path): + """ + Copy/paste/tweak model's weights to transformers design. + """ + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + downstream_dict = checkpoint["Downstream"] + + hf_config = UniSpeechSatConfig.from_pretrained(config_path) + hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + base_model_name, return_attention_mask=True, do_normalize=False + ) + + arch = hf_config.architectures[0] + if arch.endswith("ForSequenceClassification"): + hf_model = convert_classification(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForAudioFrameClassification"): + hf_model = convert_diarization(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForXVector"): + hf_model = convert_xvector(base_model_name, hf_config, downstream_dict) + else: + raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}") + + if hf_config.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + + hf_feature_extractor.save_pretrained(model_dump_path) + hf_model.save_pretrained(model_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model." 
+ ) + parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.") + parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.") + args = parser.parse_args() + convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index cae266eef7..7b7009b329 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -33,7 +33,7 @@ from ...file_utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_unispeech_sat import UniSpeechSatConfig @@ -45,9 +45,11 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "UniSpeechSatConfig" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor" _CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus" -_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor" +_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd" +_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv" _HIDDEN_STATES_START_POSITION = 2 @@ -123,6 +125,38 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of 
:class:`~transformers.UniSpeechSatForXVector`. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + embeddings: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], @@ -1472,7 +1506,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) @add_code_sample_docstrings( - processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC, + processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, @@ -1538,3 +1572,285 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization. + """, + UNISPEECH_SAT_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT +class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.unispeech_sat = UniSpeechSatModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. 
+ """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech_sat.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_FRAME_CLASS_CHECKPOINT, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=None, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + super(AMSoftmaxLoss, self).__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer +class TDNNLayer(nn.Module): + def __init__(self, 
config, layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = config.tdnn_dim[layer_id] + self.kernel_size = config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + hidden_states = hidden_states.unsqueeze(1) + hidden_states = nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + UNISPEECH_SAT_START_DOCSTRING, +) +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT +class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.unispeech_sat = UniSpeechSatModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + def 
freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.unispeech_sat.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.unispeech_sat.parameters(): + param.requires_grad = False + + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_XVECTOR_CHECKPOINT, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.unispeech_sat( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/wav2vec2/__init__.py b/src/transformers/models/wav2vec2/__init__.py index fb03d8a572..db27e5e0c4 100644 --- a/src/transformers/models/wav2vec2/__init__.py +++ b/src/transformers/models/wav2vec2/__init__.py @@ -31,10 +31,12 @@ _import_structure = { if is_torch_available(): _import_structure["modeling_wav2vec2"] = [ "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", + "Wav2Vec2ForAudioFrameClassification", "Wav2Vec2ForCTC", "Wav2Vec2ForMaskedLM", "Wav2Vec2ForPreTraining", "Wav2Vec2ForSequenceClassification", + "Wav2Vec2ForXVector", "Wav2Vec2Model", "Wav2Vec2PreTrainedModel", ] @@ -65,10 +67,12 @@ if TYPE_CHECKING: if is_torch_available(): from .modeling_wav2vec2 import ( WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST, + Wav2Vec2ForAudioFrameClassification, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, Wav2Vec2Model, Wav2Vec2PreTrainedModel, ) diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index fcbfd1c41e..6f1d8c3234 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -80,10 +80,10 @@ class Wav2Vec2Config(PretrainedConfig): feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers. conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`): A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length - of `conv_stride` defines the number of convolutional layers and has to match the the length of `conv_dim`. + of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`. 
conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`): A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The - length of `conv_kernel` defines the number of convolutional layers and has to match the the length of + length of `conv_kernel` defines the number of convolutional layers and has to match the length of `conv_dim`. conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether the 1D convolutional layers have a bias. @@ -153,6 +153,17 @@ class Wav2Vec2Config(PretrainedConfig): instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. + tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`): + A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN` + module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers. + tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the + `XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`. + tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`): + A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of + the `XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`. + xvector_output_dim (:obj:`int`, `optional`, defaults to 512): + Dimensionality of the `XVector` embedding vectors. add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
@@ -226,6 +237,10 @@ class Wav2Vec2Config(PretrainedConfig): ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, + tdnn_dim=(512, 512, 512, 512, 1500), + tdnn_kernel=(5, 3, 3, 1, 1), + tdnn_dilation=(1, 2, 3, 1, 1), + xvector_output_dim=512, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -262,7 +277,6 @@ class Wav2Vec2Config(PretrainedConfig): self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm self.use_weighted_layer_sum = use_weighted_layer_sum - self.classifier_proj_size = classifier_proj_size if ( (len(self.conv_stride) != self.num_feat_extract_layers) @@ -305,3 +319,12 @@ class Wav2Vec2Config(PretrainedConfig): self.adapter_stride = adapter_stride self.num_adapter_layers = num_adapter_layers self.output_hidden_size = output_hidden_size or hidden_size + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. 
+ self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py index bd7a7370cf..bcc9fd95a4 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py @@ -19,13 +19,52 @@ import argparse import torch -from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging +from transformers import ( + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + logging, +) logging.set_verbosity_info() logger = logging.get_logger(__name__) -SUPPORTED_MODELS = ["UtteranceLevel"] + +def convert_classification(base_model_name, hf_config, downstream_dict): + model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["projector.weight"] + model.projector.bias.data = downstream_dict["projector.bias"] + model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + return model + + +def convert_diarization(base_model_name, hf_config, downstream_dict): + model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config) + model.classifier.weight.data = downstream_dict["model.linear.weight"] + model.classifier.bias.data = downstream_dict["model.linear.bias"] + return model + + +def convert_xvector(base_model_name, hf_config, downstream_dict): + model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config) + 
model.projector.weight.data = downstream_dict["connector.weight"] + model.projector.bias.data = downstream_dict["connector.bias"] + for i, kernel_size in enumerate(hf_config.tdnn_kernel): + model.tdnn[i].kernel.weight.data = downstream_dict[ + f"model.framelevel_feature_extractor.module.{i}.kernel.weight" + ] + model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"] + + model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"] + model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"] + model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"] + model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"] + model.objective.weight.data = downstream_dict["objective.W"] + return model @torch.no_grad() @@ -34,24 +73,26 @@ def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, mode Copy/paste/tweak model's weights to transformers design. 
""" checkpoint = torch.load(checkpoint_path, map_location="cpu") - if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS: - raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}") downstream_dict = checkpoint["Downstream"] - hf_congfig = Wav2Vec2Config.from_pretrained(config_path) - hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig) + hf_config = Wav2Vec2Config.from_pretrained(config_path) hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( base_model_name, return_attention_mask=True, do_normalize=False ) - if hf_congfig.use_weighted_layer_sum: - hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + arch = hf_config.architectures[0] + if arch.endswith("ForSequenceClassification"): + hf_model = convert_classification(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForAudioFrameClassification"): + hf_model = convert_diarization(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForXVector"): + hf_model = convert_xvector(base_model_name, hf_config, downstream_dict) + else: + raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}") - hf_model.projector.weight.data = downstream_dict["projector.weight"] - hf_model.projector.bias.data = downstream_dict["projector.bias"] - hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] - hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + if hf_config.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] hf_feature_extractor.save_pretrained(model_dump_path) hf_model.save_pretrained(model_dump_path) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 0f0e00fd46..046a04b2db 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ 
b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -34,7 +34,13 @@ from ...file_utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_wav2vec2 import Wav2Vec2Config @@ -45,9 +51,11 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Wav2Vec2Config" _CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h" _PROCESSOR_FOR_DOC = "Wav2Vec2Processor" +_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor" _SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks" -_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor" +_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd" +_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv" _HIDDEN_STATES_START_POSITION = 2 @@ -93,7 +101,7 @@ class Wav2Vec2BaseModelOutput(ModelOutput): @dataclass class Wav2Vec2ForPreTrainingOutput(ModelOutput): """ - Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions. + Output type of :class:`~transformers.Wav2Vec2ForPreTraining`, with potential hidden states and attentions. Args: loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`): @@ -132,6 +140,38 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput): diversity_loss: Optional[torch.FloatTensor] = None +@dataclass +class XVectorOutput(ModelOutput): + """ + Output type of :class:`~transformers.Wav2Vec2ForXVector`. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Classification loss. 
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`): + Classification hidden states before AMSoftmax. + embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`): + Utterance embeddings used for vector similarity-based retrieval. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads.
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + embeddings: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, @@ -1707,7 +1747,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) @add_code_sample_docstrings( - processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC, + processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_SEQ_CLASS_CHECKPOINT, output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC, @@ -1773,3 +1813,281 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@add_start_docstrings( + """ + Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization. + """, + WAV_2_VEC_2_START_DOCSTRING, +) +class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wav2vec2 = Wav2Vec2Model(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wav2vec2.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. 
+ """ + for param in self.wav2vec2.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_FRAME_CLASS_CHECKPOINT, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=None, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + 
super(AMSoftmaxLoss, self).__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +class TDNNLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = config.tdnn_dim[layer_id] + self.kernel_size = config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + hidden_states = hidden_states.unsqueeze(1) + hidden_states = nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification. 
+ """, + WAV_2_VEC_2_START_DOCSTRING, +) +class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wav2vec2 = Wav2Vec2Model(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wav2vec2.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. 
+ """ + for param in self.wav2vec2.parameters(): + param.requires_grad = False + + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_XVECTOR_CHECKPOINT, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + def forward( + self, + input_values, + attention_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + labels=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + 
hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index efd82bf0f0..f2225f6cc8 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -422,6 +422,30 @@ class AutoModelForAudioClassification: requires_backends(self, ["torch"]) +class AutoModelForAudioFrameClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class AutoModelForAudioXVector: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def forward(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForCausalLM: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -4896,6 +4920,11 @@ class UniSpeechPreTrainedModel: UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = None +class UniSpeechSatForAudioFrameClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class UniSpeechSatForCTC: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -4918,6 +4947,11 @@ class UniSpeechSatForSequenceClassification: requires_backends(self, ["torch"]) +class UniSpeechSatForXVector: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class UniSpeechSatModel: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -5072,6 +5106,11 @@ class ViTPreTrainedModel: WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None +class Wav2Vec2ForAudioFrameClassification: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Wav2Vec2ForCTC: def __init__(self, 
*args, **kwargs): requires_backends(self, ["torch"]) @@ -5106,6 +5145,11 @@ class Wav2Vec2ForSequenceClassification: requires_backends(self, ["torch"]) +class Wav2Vec2ForXVector: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Wav2Vec2Model: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) diff --git a/tests/test_modeling_unispeech_sat.py b/tests/test_modeling_unispeech_sat.py index 56c429d2f9..02dea6447e 100644 --- a/tests/test_modeling_unispeech_sat.py +++ b/tests/test_modeling_unispeech_sat.py @@ -33,9 +33,11 @@ if is_torch_available(): import torch from transformers import ( + UniSpeechSatForAudioFrameClassification, UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatForSequenceClassification, + UniSpeechSatForXVector, UniSpeechSatModel, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, @@ -70,6 +72,10 @@ class UniSpeechSatModelTester: mask_time_length=2, vocab_size=32, do_stable_layer_norm=False, + tdnn_dim=(32, 32), + tdnn_kernel=(3, 3), + tdnn_dilation=(1, 1), + xvector_output_dim=32, scope=None, ): self.parent = parent @@ -97,6 +103,10 @@ class UniSpeechSatModelTester: self.do_stable_layer_norm = do_stable_layer_norm self.mask_time_prob = mask_time_prob self.mask_time_length = mask_time_length + self.tdnn_dim = tdnn_dim + self.tdnn_kernel = tdnn_kernel + self.tdnn_dilation = tdnn_dilation + self.xvector_output_dim = xvector_output_dim self.scope = scope output_seq_length = self.seq_length @@ -135,6 +145,10 @@ class UniSpeechSatModelTester: hidden_act=self.hidden_act, initializer_range=self.initializer_range, vocab_size=self.vocab_size, + tdnn_dim=self.tdnn_dim, + tdnn_kernel=self.tdnn_kernel, + tdnn_dilation=self.tdnn_dilation, + xvector_output_dim=self.xvector_output_dim, ) def create_and_check_model(self, config, input_values, attention_mask): @@ -277,6 +291,30 @@ class UniSpeechSatModelTester: loss.backward() + def check_xvector_training(self, config, *args): + config.ctc_zero_infinity = 
True + model = UniSpeechSatForXVector(config=config) + model.to(torch_device) + model.train() + + # freeze everything but the classification head + model.freeze_base_model() + + # use a longer sequence length to account for TDNN temporal downsampling + input_values = floats_tensor([self.batch_size, self.seq_length * 2], self.vocab_size) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + def check_labels_out_of_vocab(self, config, input_values, *args): model = UniSpeechSatForCTC(config) model.to(torch_device) @@ -300,7 +338,14 @@ class UniSpeechSatModelTester: @require_torch class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - (UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification) + ( + UniSpeechSatForCTC, + UniSpeechSatForPreTraining, + UniSpeechSatModel, + UniSpeechSatForSequenceClassification, + UniSpeechSatForAudioFrameClassification, + UniSpeechSatForXVector, + ) if is_torch_available() else () ) @@ -335,6 +380,10 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_seq_classifier_training(*config_and_inputs) + def test_xvector_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_xvector_training(*config_and_inputs) + def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs) @@ -417,6 +466,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase): "feature_projection.projection.weight", 
"feature_projection.projection.bias", "label_embeddings_concat", + "objective.weight", ] if param.requires_grad: if any([x in name for x in uniform_init_parms]): @@ -623,6 +673,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase): "feature_projection.projection.weight", "feature_projection.projection.bias", "label_embeddings_concat", + "objective.weight", ] if param.requires_grad: if any([x in name for x in uniform_init_parms]): @@ -811,3 +862,56 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase): # fmt: on self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3)) + + def test_inference_diarization(self): + model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd").to( + torch_device + ) + processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd") + input_data = self._load_superb("sd", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + # labels is a one-hot array of shape (num_frames, num_speakers) + labels = (outputs.logits > 0).long() + + # s3prl logits for the same batch + expected_logits = torch.tensor( + [ + [[-5.6119, -5.5845], [-3.7772, -5.4824], [-3.6914, -5.1619], [-4.7560, -5.0496]], + [[-6.3785, -4.8365], [-5.5863, -5.4149], [-5.5639, -4.8469], [-6.1511, -4.0052]], + [[-6.0355, -3.7414], [-5.5968, -4.8061], [-5.4620, -4.7310], [-5.5864, -4.6078]], + [[-5.9493, -4.8963], [-4.4050, -5.4476], [-4.1755, -5.1395], [-4.0272, -4.3705]], + ], + device=torch_device, + ) + self.assertEqual(labels[0, :, 0].sum(), 270) + self.assertEqual(labels[0, :, 1].sum(), 647) + self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3)) + + def 
test_inference_speaker_verification(self): + model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv") + input_data = self._load_superb("si", 4) + + inputs = processor(input_data["speech"], return_tensors="pt", padding=True) + labels = torch.tensor([5, 1, 1, 3], device=torch_device).T + + with torch.no_grad(): + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + outputs = model(input_values, attention_mask=attention_mask, labels=labels) + embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1) + + cosine_sim = torch.nn.CosineSimilarity(dim=-1) + # id10002 vs id10002 + self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9671, 3) + # id10006 vs id10002 + self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.4941, 3) + # id10002 vs id10004 + self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.5616, 3) + + self.assertAlmostEqual(outputs.loss.item(), 18.5925, 3) diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index c3a0271bd0..182d8ee1a2 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -44,10 +44,12 @@ if is_torch_available(): from transformers import ( Wav2Vec2FeatureExtractor, + Wav2Vec2ForAudioFrameClassification, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2ForPreTraining, Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, Wav2Vec2Model, Wav2Vec2Processor, ) @@ -96,6 +98,10 @@ class Wav2Vec2ModelTester: do_stable_layer_norm=False, num_adapter_layers=1, adapter_stride=2, + tdnn_dim=(32, 32), + tdnn_kernel=(5, 3), + tdnn_dilation=(1, 2), + xvector_output_dim=32, scope=None, ): self.parent = parent @@ -126,6 +132,10 @@ class Wav2Vec2ModelTester: self.mask_time_prob = mask_time_prob self.mask_time_length = mask_time_length 
self.scope = scope + self.tdnn_dim = tdnn_dim + self.tdnn_kernel = tdnn_kernel + self.tdnn_dilation = tdnn_dilation + self.xvector_output_dim = xvector_output_dim output_seq_length = self.seq_length for kernel, stride in zip(self.conv_kernel, self.conv_stride): @@ -168,6 +178,10 @@ class Wav2Vec2ModelTester: vocab_size=self.vocab_size, num_adapter_layers=self.num_adapter_layers, adapter_stride=self.adapter_stride, + tdnn_dim=self.tdnn_dim, + tdnn_kernel=self.tdnn_kernel, + tdnn_dilation=self.tdnn_dilation, + xvector_output_dim=self.xvector_output_dim, ) def create_and_check_model(self, config, input_values, attention_mask): @@ -332,6 +346,29 @@ class Wav2Vec2ModelTester: loss.backward() + def check_xvector_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = Wav2Vec2ForXVector(config=config) + model.to(torch_device) + model.train() + + # freeze everything but the classification head + model.freeze_base_model() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label)) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + def check_labels_out_of_vocab(self, config, input_values, *args): model = Wav2Vec2ForCTC(config) model.to(torch_device) @@ -398,6 +435,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_seq_classifier_training(*config_and_inputs) + def test_xvector_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_xvector_training(*config_and_inputs) + def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() 
self.model_tester.check_labels_out_of_vocab(*config_and_inputs) @@ -489,6 +530,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): "project_q.bias", "feature_projection.projection.weight", "feature_projection.projection.bias", + "objective.weight", ] if param.requires_grad: if any([x in name for x in uniform_init_parms]): @@ -573,7 +615,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): @require_torch class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( - (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining) + ( + Wav2Vec2ForCTC, + Wav2Vec2Model, + Wav2Vec2ForMaskedLM, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForPreTraining, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForXVector, + ) if is_torch_available() else () ) @@ -622,6 +672,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_seq_classifier_training(*config_and_inputs) + def test_xvector_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_xvector_training(*config_and_inputs) + def test_labels_out_of_vocab(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_labels_out_of_vocab(*config_and_inputs) @@ -703,6 +757,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): "project_q.bias", "feature_projection.projection.weight", "feature_projection.projection.bias", + "objective.weight", ] if param.requires_grad: if any([x in name for x in uniform_init_parms]): @@ -1369,3 +1424,54 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase): transcription = processor.batch_decode(logits.cpu().numpy()).text self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero") + + def test_inference_diarization(self): + model = 
Wav2Vec2ForAudioFrameClassification.from_pretrained("anton-l/wav2vec2-base-superb-sd").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sd") + input_data = self._load_superb("sd", 4) + inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000) + + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + with torch.no_grad(): + outputs = model(input_values, attention_mask=attention_mask) + # labels is a one-hot array of shape (num_frames, num_speakers) + labels = (outputs.logits > 0).long() + + # s3prl logits for the same batch + expected_logits = torch.tensor( + [ + [[-5.2807, -5.1272], [-5.4059, -4.7757], [-5.2764, -4.9621], [-5.0117, -4.5851]], + [[-1.7643, -0.5462], [-1.7369, -0.2649], [-1.5066, -0.6200], [-4.5703, -2.4863]], + [[-0.8656, -0.4783], [-0.8899, -0.3289], [-0.9267, -0.5781], [-0.7817, -0.4619]], + [[-4.8625, -2.5316], [-5.2339, -2.2155], [-4.9835, -2.0344], [-4.4727, -1.8421]], + ], + device=torch_device, + ) + self.assertEqual(labels[0, :, 0].sum(), 555) + self.assertEqual(labels[0, :, 1].sum(), 299) + self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3)) + + def test_inference_speaker_verification(self): + model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device) + processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sv") + input_data = self._load_superb("si", 4) + + inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000) + labels = torch.tensor([5, 1, 1, 3], device=torch_device).T + + with torch.no_grad(): + input_values = inputs.input_values.to(torch_device) + attention_mask = inputs.attention_mask.to(torch_device) + outputs = model(input_values, attention_mask=attention_mask, labels=labels) + embeddings = torch.nn.functional.normalize(outputs.embeddings, 
dim=-1).cpu() + + cosine_sim = torch.nn.CosineSimilarity(dim=-1) + # id10002 vs id10002 + self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).numpy(), 0.9758, 3) + # id10006 vs id10002 + self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).numpy(), 0.7579, 3) + # id10002 vs id10004 + self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).numpy(), 0.7594, 3) + + self.assertAlmostEqual(outputs.loss.item(), 17.7963, 3) diff --git a/utils/update_metadata.py b/utils/update_metadata.py index 78123abadb..9d209d41eb 100644 --- a/utils/update_metadata.py +++ b/utils/update_metadata.py @@ -74,6 +74,12 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [ "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES", "AutoModelForNextSentencePrediction", ), + ( + "audio-frame-classification", + "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES", + "AutoModelForAudioFrameClassification", + ), + ("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"), ]