Add Speaker Diarization and Verification heads (#14723)
* Models * Squashed commit of the following: commit 72278e1e931a16d0879acc77f65762f3364833d0 Author: anton-l <aglozhkov@gmail.com> Date: Fri Dec 10 21:45:08 2021 +0300 * Add unispeech heads * Add sd/sv automodels * Docs cleanup * Fix docstrings * rename xvector classes * examples * Tests cleanup * Style * Better checkpoints for tests * leftover docs * apply review suggestions * Style + init tests * Update unispeech-sat tdnn downsampling
This commit is contained in:
@@ -181,6 +181,13 @@ AutoModelForAudioClassification
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForAudioFrameClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoModelForAudioFrameClassification
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForCTC
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -195,6 +202,13 @@ AutoModelForSpeechSeq2Seq
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForAudioXVector
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.AutoModelForAudioXVector
|
||||
:members:
|
||||
|
||||
|
||||
AutoModelForObjectDetection
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -85,6 +85,20 @@ UniSpeechSatForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
UniSpeechSatForAudioFrameClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.UniSpeechSatForAudioFrameClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
UniSpeechSatForXVector
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.UniSpeechSatForXVector
|
||||
:members: forward
|
||||
|
||||
|
||||
UniSpeechSatForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -114,6 +114,20 @@ Wav2Vec2ForSequenceClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
Wav2Vec2ForAudioFrameClassification
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Wav2Vec2ForAudioFrameClassification
|
||||
:members: forward
|
||||
|
||||
|
||||
Wav2Vec2ForXVector
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Wav2Vec2ForXVector
|
||||
:members: forward
|
||||
|
||||
|
||||
Wav2Vec2ForPreTraining
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -649,6 +649,8 @@ if is_torch_available():
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"AutoModel",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForAudioFrameClassification",
|
||||
"AutoModelForAudioXVector",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForImageClassification",
|
||||
@@ -1325,9 +1327,11 @@ if is_torch_available():
|
||||
_import_structure["models.unispeech_sat"].extend(
|
||||
[
|
||||
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"UniSpeechSatForAudioFrameClassification",
|
||||
"UniSpeechSatForCTC",
|
||||
"UniSpeechSatForPreTraining",
|
||||
"UniSpeechSatForSequenceClassification",
|
||||
"UniSpeechSatForXVector",
|
||||
"UniSpeechSatModel",
|
||||
"UniSpeechSatPreTrainedModel",
|
||||
]
|
||||
@@ -1358,10 +1362,12 @@ if is_torch_available():
|
||||
_import_structure["models.wav2vec2"].extend(
|
||||
[
|
||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"Wav2Vec2ForAudioFrameClassification",
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"Wav2Vec2ForXVector",
|
||||
"Wav2Vec2Model",
|
||||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
@@ -2603,6 +2609,8 @@ if TYPE_CHECKING:
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForAudioFrameClassification,
|
||||
AutoModelForAudioXVector,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
@@ -3164,9 +3172,11 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.unispeech_sat import (
|
||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForXVector,
|
||||
UniSpeechSatModel,
|
||||
UniSpeechSatPreTrainedModel,
|
||||
)
|
||||
@@ -3191,10 +3201,12 @@ if TYPE_CHECKING:
|
||||
)
|
||||
from .models.wav2vec2 import (
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -1117,6 +1117,54 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
|
||||
"""
|
||||
|
||||
|
||||
PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {processor_class}, {model_class}
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
|
||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||
|
||||
>>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> # audio file is decoded on the fly
|
||||
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
|
||||
>>> logits = model(**inputs).logits
|
||||
>>> probabilities = torch.sigmoid(logits[0])
|
||||
>>> # labels is a one-hot array of shape (num_frames, num_speakers)
|
||||
>>> labels = (probabilities > 0.5).long()
|
||||
"""
|
||||
|
||||
|
||||
PT_SPEECH_XVECTOR_SAMPLE = r"""
|
||||
Example::
|
||||
|
||||
>>> from transformers import {processor_class}, {model_class}
|
||||
>>> from datasets import load_dataset
|
||||
>>> import torch
|
||||
|
||||
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||
|
||||
>>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}')
|
||||
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||
|
||||
>>> # audio file is decoded on the fly
|
||||
>>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt")
|
||||
>>> embeddings = model(**inputs).embeddings
|
||||
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
|
||||
|
||||
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
|
||||
>>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
|
||||
>>> similarity = cosine_sim(embeddings[0], embeddings[1])
|
||||
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
|
||||
>>> if similarity < threshold:
|
||||
... print("Speakers are not the same!")
|
||||
"""
|
||||
|
||||
PT_SAMPLE_DOCSTRINGS = {
|
||||
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
|
||||
@@ -1128,6 +1176,8 @@ PT_SAMPLE_DOCSTRINGS = {
|
||||
"SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
|
||||
"CTC": PT_SPEECH_CTC_SAMPLE,
|
||||
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
|
||||
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
|
||||
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
|
||||
}
|
||||
|
||||
|
||||
@@ -1419,6 +1469,10 @@ def add_code_sample_docstrings(
|
||||
code_sample = sample_docstrings["LMHead"]
|
||||
elif "CTC" in model_class:
|
||||
code_sample = sample_docstrings["CTC"]
|
||||
elif "AudioFrameClassification" in model_class:
|
||||
code_sample = sample_docstrings["AudioFrameClassification"]
|
||||
elif "XVector" in model_class and modality == "audio":
|
||||
code_sample = sample_docstrings["AudioXVector"]
|
||||
elif "Model" in model_class and modality == "audio":
|
||||
code_sample = sample_docstrings["SpeechBaseModel"]
|
||||
elif "Model" in model_class or "Encoder" in model_class:
|
||||
|
||||
@@ -53,6 +53,8 @@ if is_torch_available():
|
||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||
"AutoModel",
|
||||
"AutoModelForAudioClassification",
|
||||
"AutoModelForAudioFrameClassification",
|
||||
"AutoModelForAudioXVector",
|
||||
"AutoModelForCausalLM",
|
||||
"AutoModelForCTC",
|
||||
"AutoModelForImageClassification",
|
||||
@@ -161,6 +163,8 @@ if TYPE_CHECKING:
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
AutoModel,
|
||||
AutoModelForAudioClassification,
|
||||
AutoModelForAudioFrameClassification,
|
||||
AutoModelForAudioXVector,
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForCTC,
|
||||
AutoModelForImageClassification,
|
||||
|
||||
@@ -538,6 +538,22 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
# Maps a model type to the name of its audio frame classification class
# (the head this commit adds for tasks like speaker diarization).
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio Frame Classification mapping
        ("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
        ("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
    ]
)
|
||||
|
||||
# Maps a model type to the name of its x-vector (utterance embedding) class.
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
    [
        # Model for Audio XVector mapping
        ("wav2vec2", "Wav2Vec2ForXVector"),
        ("unispeech-sat", "UniSpeechSatForXVector"),
    ]
)
|
||||
|
||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||
@@ -578,6 +594,10 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
)
|
||||
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
|
||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
|
||||
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
|
||||
|
||||
|
||||
class AutoModel(_BaseAutoModelClass):
|
||||
@@ -726,6 +746,22 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
|
||||
|
||||
|
||||
AutoModelForAudioFrameClassification = auto_class_update(
|
||||
AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
|
||||
)
|
||||
|
||||
|
||||
class AutoModelForAudioXVector(_BaseAutoModelClass):
|
||||
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
|
||||
|
||||
|
||||
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
|
||||
|
||||
|
||||
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
|
||||
@@ -42,9 +42,9 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "HubertConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
|
||||
@@ -1182,7 +1182,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
|
||||
@@ -38,9 +38,9 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "SEWConfig"
|
||||
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
|
||||
@@ -1067,7 +1067,7 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
|
||||
@@ -39,9 +39,9 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "SEWDConfig"
|
||||
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 1
|
||||
|
||||
@@ -1598,7 +1598,7 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
|
||||
@@ -44,9 +44,9 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "UniSpeechConfig"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 2
|
||||
|
||||
@@ -1481,7 +1481,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
|
||||
@@ -27,9 +27,11 @@ _import_structure = {
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_unispeech_sat"] = [
|
||||
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"UniSpeechSatForAudioFrameClassification",
|
||||
"UniSpeechSatForCTC",
|
||||
"UniSpeechSatForPreTraining",
|
||||
"UniSpeechSatForSequenceClassification",
|
||||
"UniSpeechSatForXVector",
|
||||
"UniSpeechSatModel",
|
||||
"UniSpeechSatPreTrainedModel",
|
||||
]
|
||||
@@ -40,9 +42,11 @@ if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
from .modeling_unispeech_sat import (
|
||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForXVector,
|
||||
UniSpeechSatModel,
|
||||
UniSpeechSatPreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -153,6 +153,17 @@ class UniSpeechSatConfig(PretrainedConfig):
|
||||
instance of :class:`~transformers.UniSpeechSatForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
|
||||
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
|
||||
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
|
||||
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
|
||||
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
|
||||
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
|
||||
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
|
||||
A tuple of integers defining the dilation factor of each 1D convolutional layer in `TDNN` module of the
|
||||
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
|
||||
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
|
||||
Dimensionality of the `XVector` embedding vectors.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -213,6 +224,10 @@ class UniSpeechSatConfig(PretrainedConfig):
|
||||
ctc_zero_infinity=False,
|
||||
use_weighted_layer_sum=False,
|
||||
classifier_proj_size=256,
|
||||
tdnn_dim=(512, 512, 512, 512, 1500),
|
||||
tdnn_kernel=(5, 3, 3, 1, 1),
|
||||
tdnn_dilation=(1, 2, 3, 1, 1),
|
||||
xvector_output_dim=512,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
@@ -246,7 +261,6 @@ class UniSpeechSatConfig(PretrainedConfig):
|
||||
self.num_clusters = num_clusters
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
if (
|
||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||
@@ -282,3 +296,12 @@ class UniSpeechSatConfig(PretrainedConfig):
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
# XVector-specific parameters. Feel free to ignore for other classes.
|
||||
self.tdnn_dim = list(tdnn_dim)
|
||||
self.tdnn_kernel = list(tdnn_kernel)
|
||||
self.tdnn_dilation = list(tdnn_dilation)
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert UniSpeechSat checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
UniSpeechSatConfig,
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForXVector,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def convert_classification(base_model_name, hf_config, downstream_dict):
    """
    Build a sequence classification model and overwrite its head with the downstream weights.

    Args:
        base_model_name: Name or path handed to ``from_pretrained`` to load the base model.
        hf_config: Configuration passed as ``config`` when instantiating the model.
        downstream_dict: Mapping from downstream state-dict keys to tensors.

    Returns:
        The :class:`UniSpeechSatForSequenceClassification` instance whose projector and classifier parameters
        have been replaced with the tensors from ``downstream_dict``.
    """
    hf_model = UniSpeechSatForSequenceClassification.from_pretrained(base_model_name, config=hf_config)

    # (target parameter, downstream key) pairs, copied in this order
    head_parameters = (
        (hf_model.projector.weight, "projector.weight"),
        (hf_model.projector.bias, "projector.bias"),
        (hf_model.classifier.weight, "model.post_net.linear.weight"),
        (hf_model.classifier.bias, "model.post_net.linear.bias"),
    )
    for parameter, downstream_key in head_parameters:
        parameter.data = downstream_dict[downstream_key]

    return hf_model
|
||||
|
||||
|
||||
def convert_diarization(base_model_name, hf_config, downstream_dict):
    """
    Build an audio frame classification model and overwrite its classifier with the downstream weights.

    Args:
        base_model_name: Name or path handed to ``from_pretrained`` to load the base model.
        hf_config: Configuration passed as ``config`` when instantiating the model.
        downstream_dict: Mapping from downstream state-dict keys to tensors.

    Returns:
        The :class:`UniSpeechSatForAudioFrameClassification` instance whose classifier weight and bias have been
        replaced with ``model.linear.weight`` / ``model.linear.bias`` from ``downstream_dict``.
    """
    hf_model = UniSpeechSatForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)

    classifier = hf_model.classifier
    classifier.weight.data = downstream_dict["model.linear.weight"]
    classifier.bias.data = downstream_dict["model.linear.bias"]

    return hf_model
|
||||
|
||||
|
||||
def convert_xvector(base_model_name, hf_config, downstream_dict):
    """
    Build an x-vector model and overwrite its head with the downstream weights.

    Args:
        base_model_name: Name or path handed to ``from_pretrained`` to load the base model.
        hf_config: Configuration passed as ``config`` when instantiating the model. The length of
            ``hf_config.tdnn_kernel`` determines how many TDNN layers are copied.
        downstream_dict: Mapping from downstream state-dict keys to tensors.

    Returns:
        The :class:`UniSpeechSatForXVector` instance whose projector, TDNN kernels, feature extractor, classifier
        and objective weight have been replaced with the tensors from ``downstream_dict``.
    """
    model = UniSpeechSatForXVector.from_pretrained(base_model_name, config=hf_config)

    model.projector.weight.data = downstream_dict["connector.weight"]
    model.projector.bias.data = downstream_dict["connector.bias"]

    # Only the layer index is needed here; the kernel sizes themselves are not used for the copy.
    for i in range(len(hf_config.tdnn_kernel)):
        kernel_prefix = f"model.framelevel_feature_extractor.module.{i}.kernel"
        model.tdnn[i].kernel.weight.data = downstream_dict[f"{kernel_prefix}.weight"]
        model.tdnn[i].kernel.bias.data = downstream_dict[f"{kernel_prefix}.bias"]

    model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
    model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
    model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
    model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
    model.objective.weight.data = downstream_dict["objective.W"]
    return model
|
||||
|
||||
|
||||
@torch.no_grad()
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
    """
    Copy/paste/tweak model's weights to transformers design.

    Loads the checkpoint at ``checkpoint_path``, builds the model class selected by the first entry of the
    config's ``architectures``, copies the ``"Downstream"`` weights into it and saves the model together with
    its feature extractor to ``model_dump_path``.

    Args:
        base_model_name: Name or path of the pretrained base model; also used to load the feature extractor.
        config_path: Path handed to ``UniSpeechSatConfig.from_pretrained``.
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        model_dump_path: Directory the converted model and feature extractor are saved to.

    Raises:
        ValueError: If the loaded config does not declare any ``architectures``.
        NotImplementedError: If the declared architecture is not one of the supported heads.
    """
    # NOTE: torch.load unpickles the file, so only convert checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    downstream_dict = checkpoint["Downstream"]

    hf_config = UniSpeechSatConfig.from_pretrained(config_path)
    hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        base_model_name, return_attention_mask=True, do_normalize=False
    )

    # The target head is chosen from the config, so fail early with a readable message if it is missing.
    if not hf_config.architectures:
        raise ValueError(
            f"The config at {config_path} does not define `architectures`; cannot determine which model to convert."
        )

    arch = hf_config.architectures[0]
    if arch.endswith("ForSequenceClassification"):
        hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
    elif arch.endswith("ForAudioFrameClassification"):
        hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
    elif arch.endswith("ForXVector"):
        hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
    else:
        raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")

    # The layer-mixing weights live in a separate section of the checkpoint.
    if hf_config.use_weighted_layer_sum:
        hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]

    hf_feature_extractor.save_pretrained(model_dump_path)
    hf_model.save_pretrained(model_dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # (flag, help text) for every command line option; all are optional strings defaulting to None
    cli_options = (
        ("--base_model_name", "Name of the huggingface pretrained base model."),
        ("--config_path", "Path to the huggingface classifier config."),
        ("--checkpoint_path", "Path to the s3prl checkpoint."),
        ("--model_dump_path", "Path to the final converted model."),
    )

    parser = argparse.ArgumentParser()
    for flag, help_text in cli_options:
        parser.add_argument(flag, default=None, type=str, help=help_text)
    args = parser.parse_args()

    convert_s3prl_checkpoint(
        args.base_model_name,
        args.config_path,
        args.checkpoint_path,
        args.model_dump_path,
    )
|
||||
@@ -33,7 +33,7 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_unispeech_sat import UniSpeechSatConfig
|
||||
@@ -45,9 +45,11 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
|
||||
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 2
|
||||
|
||||
@@ -123,6 +125,38 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class XVectorOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.UniSpeechSatForXVector`.
|
||||
|
||||
Args:
|
||||
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
||||
Classification loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
|
||||
Classification hidden states before AMSoftmax.
|
||||
embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
|
||||
Utterance embeddings used for vector similarity-based retrieval.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
||||
heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
logits: torch.FloatTensor = None
|
||||
embeddings: torch.FloatTensor = None
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
@@ -1472,7 +1506,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
@@ -1538,3 +1572,285 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"""
|
||||
UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization.
|
||||
""",
|
||||
UNISPEECH_SAT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
|
||||
class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.unispeech_sat = UniSpeechSatModel(config)
|
||||
num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
|
||||
if config.use_weighted_layer_sum:
|
||||
self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
|
||||
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameters
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.unispeech_sat.feature_extractor._freeze_parameters()
|
||||
|
||||
def freeze_base_model(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the base model so that its parameters will not
|
||||
be updated during training. Only the classification head will be updated.
|
||||
"""
|
||||
for param in self.unispeech_sat.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_FRAME_CLASS_CHECKPOINT,
|
||||
output_type=TokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
modality="audio",
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
||||
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
||||
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
||||
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
|
||||
|
||||
outputs = self.unispeech_sat(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
if self.config.use_weighted_layer_sum:
|
||||
hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
|
||||
hidden_states = torch.stack(hidden_states, dim=1)
|
||||
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
|
||||
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
|
||||
else:
|
||||
hidden_states = outputs[0]
|
||||
|
||||
logits = self.classifier(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
|
||||
return output
|
||||
|
||||
return TokenClassifierOutput(
|
||||
loss=None,
|
||||
logits=logits,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
|
||||
class AMSoftmaxLoss(nn.Module):
    # Additive-margin softmax loss: cross-entropy over scaled cosine similarities between the normalized
    # inputs and the normalized columns of `weight`, with `margin` subtracted from the target-class cosine.

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        # One column per label; columns are L2-normalized on every forward pass.
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        # `labels` may carry extra singleton dimensions; the loss needs a flat vector of class indices.
        targets = labels.flatten()
        class_directions = nn.functional.normalize(self.weight, dim=0)
        unit_embeddings = nn.functional.normalize(hidden_states, dim=1)
        cosine = torch.mm(unit_embeddings, class_directions)

        # Only the entry of the target class is penalized by the margin.
        is_target = nn.functional.one_hot(targets, self.num_labels).bool()
        margin_logits = self.scale * torch.where(is_target, cosine - self.margin, cosine)
        return self.loss(margin_logits, targets)
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
|
||||
class TDNNLayer(nn.Module):
    # Dilated temporal context layer: `kernel_size` frames spaced `dilation` steps apart are concatenated
    # and passed through a single linear projection followed by a ReLU.

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Layer 0 has no predecessor, so its input width is its own configured width.
        input_layer_id = layer_id - 1 if layer_id > 0 else layer_id
        self.in_conv_dim = config.tdnn_dim[input_layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        # Add a singleton channel dim so `unfold` can slide a (kernel_size x in_conv_dim) window over dim 1.
        frames = hidden_states.unsqueeze(1)
        windows = nn.functional.unfold(
            frames,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        # `unfold` puts the flattened window on dim 1; move it last for the linear projection.
        projected = self.kernel(windows.transpose(1, 2))
        return self.activation(projected)
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.unispeech_sat = UniSpeechSatModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Statistic pooling concatenates mean and std, hence twice the last TDNN width as input.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.unispeech_sat.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.unispeech_sat.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers, taking each layer's kernel size and dilation into account
        """

        def _conv_out_length(input_length, kernel_size, dilation, stride):
            # 1D convolutional layer output length formula (without padding) taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        # Each TDNN layer unfolds its input with a dilated window, so a layer shortens the sequence by
        # `dilation * (kernel_size - 1)` frames. Ignoring the dilation would overestimate the valid length.
        for kernel_size, dilation in zip(self.config.tdnn_kernel, self.config.tdnn_dilation):
            input_lengths = _conv_out_length(input_lengths, kernel_size, dilation, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the additive-margin softmax (AM-Softmax) classification loss. Indices should be in
            :obj:`[0, ..., config.num_labels - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states, so force the encoder to return them.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.unispeech_sat(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # Pool each sample only over the frames that correspond to its unpadded input.
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
@@ -31,10 +31,12 @@ _import_structure = {
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_wav2vec2"] = [
|
||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"Wav2Vec2ForAudioFrameClassification",
|
||||
"Wav2Vec2ForCTC",
|
||||
"Wav2Vec2ForMaskedLM",
|
||||
"Wav2Vec2ForPreTraining",
|
||||
"Wav2Vec2ForSequenceClassification",
|
||||
"Wav2Vec2ForXVector",
|
||||
"Wav2Vec2Model",
|
||||
"Wav2Vec2PreTrainedModel",
|
||||
]
|
||||
@@ -65,10 +67,12 @@ if TYPE_CHECKING:
|
||||
if is_torch_available():
|
||||
from .modeling_wav2vec2 import (
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2PreTrainedModel,
|
||||
)
|
||||
|
||||
@@ -80,10 +80,10 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
|
||||
conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
|
||||
A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
|
||||
of `conv_stride` defines the number of convolutional layers and has to match the the length of `conv_dim`.
|
||||
of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
|
||||
conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`):
|
||||
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
|
||||
length of `conv_kernel` defines the number of convolutional layers and has to match the the length of
|
||||
length of `conv_kernel` defines the number of convolutional layers and has to match the length of
|
||||
`conv_dim`.
|
||||
conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether the 1D convolutional layers have a bias.
|
||||
@@ -153,6 +153,17 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
|
||||
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
|
||||
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
|
||||
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
|
||||
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
|
||||
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
|
||||
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
|
||||
A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of the
|
||||
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
|
||||
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
|
||||
Dimensionality of the `XVector` embedding vectors.
|
||||
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
|
||||
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
|
||||
@@ -226,6 +237,10 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
ctc_zero_infinity=False,
|
||||
use_weighted_layer_sum=False,
|
||||
classifier_proj_size=256,
|
||||
tdnn_dim=(512, 512, 512, 512, 1500),
|
||||
tdnn_kernel=(5, 3, 3, 1, 1),
|
||||
tdnn_dilation=(1, 2, 3, 1, 1),
|
||||
xvector_output_dim=512,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
@@ -262,7 +277,6 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
if (
|
||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||
@@ -305,3 +319,12 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.adapter_stride = adapter_stride
|
||||
self.num_adapter_layers = num_adapter_layers
|
||||
self.output_hidden_size = output_hidden_size or hidden_size
|
||||
|
||||
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
|
||||
self.classifier_proj_size = classifier_proj_size
|
||||
|
||||
# XVector-specific parameters. Feel free to ignore for other classes.
|
||||
self.tdnn_dim = list(tdnn_dim)
|
||||
self.tdnn_kernel = list(tdnn_kernel)
|
||||
self.tdnn_dilation = list(tdnn_dilation)
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
@@ -19,13 +19,52 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging
|
||||
from transformers import (
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
SUPPORTED_MODELS = ["UtteranceLevel"]
|
||||
|
||||
def convert_classification(base_model_name, hf_config, downstream_dict):
    """
    Load a sequence classification model from `base_model_name` and overwrite its projector and classifier
    parameters with the matching tensors from the s3prl downstream state dict.
    """
    model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
    # (target parameter, key in the downstream state dict)
    head_parameters = (
        (model.projector.weight, "projector.weight"),
        (model.projector.bias, "projector.bias"),
        (model.classifier.weight, "model.post_net.linear.weight"),
        (model.classifier.bias, "model.post_net.linear.bias"),
    )
    for parameter, source_key in head_parameters:
        parameter.data = downstream_dict[source_key]
    return model
|
||||
|
||||
|
||||
def convert_diarization(base_model_name, hf_config, downstream_dict):
    """
    Load an audio frame classification model from `base_model_name` and overwrite its classifier parameters
    with the linear layer stored in the s3prl downstream state dict.
    """
    model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
    frame_head = model.classifier
    frame_head.weight.data = downstream_dict["model.linear.weight"]
    frame_head.bias.data = downstream_dict["model.linear.bias"]
    return model
|
||||
|
||||
|
||||
def convert_xvector(base_model_name, hf_config, downstream_dict):
    """
    Load an XVector model from `base_model_name` and overwrite its head parameters (projector, TDNN layers,
    utterance-level linear layers and the objective weight) with tensors from the s3prl downstream state dict.
    """
    model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config)
    model.projector.weight.data = downstream_dict["connector.weight"]
    model.projector.bias.data = downstream_dict["connector.bias"]
    # Only the layer index is needed here; the kernel sizes themselves are not used for the copy.
    for i in range(len(hf_config.tdnn_kernel)):
        model.tdnn[i].kernel.weight.data = downstream_dict[
            f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
        ]
        model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]

    model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
    model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
    model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
    model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
    model.objective.weight.data = downstream_dict["objective.W"]
    return model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -34,24 +73,26 @@ def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, mode
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
|
||||
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
|
||||
|
||||
downstream_dict = checkpoint["Downstream"]
|
||||
|
||||
hf_congfig = Wav2Vec2Config.from_pretrained(config_path)
|
||||
hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
|
||||
hf_config = Wav2Vec2Config.from_pretrained(config_path)
|
||||
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
base_model_name, return_attention_mask=True, do_normalize=False
|
||||
)
|
||||
|
||||
if hf_congfig.use_weighted_layer_sum:
|
||||
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
||||
arch = hf_config.architectures[0]
|
||||
if arch.endswith("ForSequenceClassification"):
|
||||
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
|
||||
elif arch.endswith("ForAudioFrameClassification"):
|
||||
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
|
||||
elif arch.endswith("ForXVector"):
|
||||
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
|
||||
else:
|
||||
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
|
||||
|
||||
hf_model.projector.weight.data = downstream_dict["projector.weight"]
|
||||
hf_model.projector.bias.data = downstream_dict["projector.bias"]
|
||||
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
|
||||
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
|
||||
if hf_config.use_weighted_layer_sum:
|
||||
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
||||
|
||||
hf_feature_extractor.save_pretrained(model_dump_path)
|
||||
hf_model.save_pretrained(model_dump_path)
|
||||
|
||||
@@ -34,7 +34,13 @@ from ...file_utils import (
|
||||
add_start_docstrings_to_model_forward,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutput,
|
||||
CausalLMOutput,
|
||||
MaskedLMOutput,
|
||||
SequenceClassifierOutput,
|
||||
TokenClassifierOutput,
|
||||
)
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...utils import logging
|
||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||
@@ -45,9 +51,11 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
|
||||
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
|
||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||
_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd"
|
||||
_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv"
|
||||
|
||||
_HIDDEN_STATES_START_POSITION = 2
|
||||
|
||||
@@ -93,7 +101,7 @@ class Wav2Vec2BaseModelOutput(ModelOutput):
|
||||
@dataclass
|
||||
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||
"""
|
||||
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
|
||||
Output type of :class:`~transformers.Wav2Vec2ForPreTraining`, with potential hidden states and attentions.
|
||||
|
||||
Args:
|
||||
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||
@@ -132,6 +140,38 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||
diversity_loss: Optional[torch.FloatTensor] = None
|
||||
|
||||
|
||||
@dataclass
class XVectorOutput(ModelOutput):
    """
    Output type of :class:`~transformers.Wav2Vec2ForXVector`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Classification hidden states before AMSoftmax.
        embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Utterance embeddings used for vector similarity-based retrieval.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Every field defaults to `None`, so the annotations are spelled as `Optional`.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    embeddings: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
@@ -1707,7 +1747,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
||||
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||
output_type=SequenceClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
@@ -1773,3 +1813,281 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_FRAME_CLASS_CHECKPOINT,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # NOTE: this head accepts no `labels` argument and computes no loss, so the `loss` field of the returned
        # output is always `None`. The argument documentation is supplied by the decorators above.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states, so force the encoder to return them.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer hidden states along a new dim 1 and blend them with softmax-normalized weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        # The classifier is applied to the un-pooled hidden states.
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
|
||||
class AMSoftmaxLoss(nn.Module):
    # Additive-margin softmax loss: cross-entropy over scaled cosine similarities between the normalized
    # inputs and the normalized columns of `weight`, with `margin` subtracted from the target-class cosine.

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        # One column per label; columns are L2-normalized on every forward pass.
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        # `labels` may carry extra singleton dimensions; the loss needs a flat vector of class indices.
        targets = labels.flatten()
        class_directions = nn.functional.normalize(self.weight, dim=0)
        unit_embeddings = nn.functional.normalize(hidden_states, dim=1)
        cosine = torch.mm(unit_embeddings, class_directions)

        # Only the entry of the target class is penalized by the margin.
        is_target = nn.functional.one_hot(targets, self.num_labels).bool()
        margin_logits = self.scale * torch.where(is_target, cosine - self.margin, cosine)
        return self.loss(margin_logits, targets)
|
||||
|
||||
|
||||
class TDNNLayer(nn.Module):
    # Dilated temporal context layer: `kernel_size` frames spaced `dilation` steps apart are concatenated
    # and passed through a single linear projection followed by a ReLU.

    def __init__(self, config, layer_id=0):
        super().__init__()
        # Layer 0 has no predecessor, so its input width is its own configured width.
        input_layer_id = layer_id - 1 if layer_id > 0 else layer_id
        self.in_conv_dim = config.tdnn_dim[input_layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        # Add a singleton channel dim so `unfold` can slide a (kernel_size x in_conv_dim) window over dim 1.
        frames = hidden_states.unsqueeze(1)
        windows = nn.functional.unfold(
            frames,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        # `unfold` puts the flattened window on dim 1; move it last for the linear projection.
        projected = self.kernel(windows.transpose(1, 2))
        return self.activation(projected)
|
||||
|
||||
|
||||
@add_start_docstrings(
    """
    Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Statistic pooling concatenates mean and std, hence twice the last TDNN width as input.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers, taking each layer's kernel size and dilation into account
        """

        def _conv_out_length(input_length, kernel_size, dilation, stride):
            # 1D convolutional layer output length formula (without padding) taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        # Each TDNN layer unfolds its input with a dilated window, so a layer shortens the sequence by
        # `dilation * (kernel_size - 1)` frames. Ignoring the dilation would overestimate the valid length.
        for kernel_size, dilation in zip(self.config.tdnn_kernel, self.config.tdnn_dilation):
            input_lengths = _conv_out_length(input_lengths, kernel_size, dilation, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the additive-margin softmax (AM-Softmax) classification loss. Indices should be in
            :obj:`[0, ..., config.num_labels - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden states, so force the encoder to return them.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # Pool each sample only over the frames that correspond to its unpadded input.
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||
|
||||
@@ -422,6 +422,30 @@ class AutoModelForAudioClassification:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForAudioFrameClassification:
    """Placeholder for ``AutoModelForAudioFrameClassification`` that needs the torch backend.

    The constructor, ``from_pretrained`` and ``forward`` do nothing except
    delegate to ``requires_backends`` with the ``"torch"`` backend
    (presumably raising a helpful error when PyTorch is unavailable --
    confirm against ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted for signature compatibility and ignored.
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Checked on the class itself, since no instance exists yet.
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForAudioXVector:
    """Placeholder for ``AutoModelForAudioXVector`` that needs the torch backend.

    The constructor, ``from_pretrained`` and ``forward`` do nothing except
    delegate to ``requires_backends`` with the ``"torch"`` backend
    (presumably raising a helpful error when PyTorch is unavailable --
    confirm against ``requires_backends``).
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted for signature compatibility and ignored.
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Checked on the class itself, since no instance exists yet.
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class AutoModelForCausalLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@@ -4896,6 +4920,11 @@ class UniSpeechPreTrainedModel:
|
||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class UniSpeechSatForAudioFrameClassification:
    """Torch-dependent placeholder; instantiation only calls ``requires_backends``."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored; only the backend check runs.
        requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UniSpeechSatForCTC:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@@ -4918,6 +4947,11 @@ class UniSpeechSatForSequenceClassification:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UniSpeechSatForXVector:
    """Torch-dependent placeholder; instantiation only calls ``requires_backends``."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored; only the backend check runs.
        requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class UniSpeechSatModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@@ -5072,6 +5106,11 @@ class ViTPreTrainedModel:
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class Wav2Vec2ForAudioFrameClassification:
    """Torch-dependent placeholder; instantiation only calls ``requires_backends``."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored; only the backend check runs.
        requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2ForCTC:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
@@ -5106,6 +5145,11 @@ class Wav2Vec2ForSequenceClassification:
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2ForXVector:
    """Torch-dependent placeholder; instantiation only calls ``requires_backends``."""

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored; only the backend check runs.
        requires_backends(self, ["torch"])
|
||||
|
||||
|
||||
class Wav2Vec2Model:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@@ -33,9 +33,11 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForXVector,
|
||||
UniSpeechSatModel,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Processor,
|
||||
@@ -70,6 +72,10 @@ class UniSpeechSatModelTester:
|
||||
mask_time_length=2,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=False,
|
||||
tdnn_dim=(32, 32),
|
||||
tdnn_kernel=(3, 3),
|
||||
tdnn_dilation=(1, 1),
|
||||
xvector_output_dim=32,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -97,6 +103,10 @@ class UniSpeechSatModelTester:
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.tdnn_dim = tdnn_dim
|
||||
self.tdnn_kernel = tdnn_kernel
|
||||
self.tdnn_dilation = tdnn_dilation
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
self.scope = scope
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
@@ -135,6 +145,10 @@ class UniSpeechSatModelTester:
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
tdnn_dim=self.tdnn_dim,
|
||||
tdnn_kernel=self.tdnn_kernel,
|
||||
tdnn_dilation=self.tdnn_dilation,
|
||||
xvector_output_dim=self.xvector_output_dim,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
@@ -277,6 +291,30 @@ class UniSpeechSatModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_xvector_training(self, config, *args):
    """Run one training step of ``UniSpeechSatForXVector`` on zero-padded input.

    Builds the model from ``config``, freezes its base model, zeroes the tail
    of the first three rows of a fresh input batch, then asserts that the
    resulting loss is not infinite and back-propagates it. Extra positional
    arguments are ignored.
    """
    config.ctc_zero_infinity = True
    model = UniSpeechSatForXVector(config=config)
    model.to(torch_device)
    model.train()

    # freeze the base model; per the original note only the classification head keeps training
    model.freeze_base_model()

    # use a longer sequence length to account for TDNN temporal downsampling
    input_values = floats_tensor([self.batch_size, self.seq_length * 2], self.vocab_size)

    total_length = input_values.shape[-1]
    input_lengths = [total_length // divisor for divisor in (4, 2, 1)]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input: everything past each row's nominal length becomes zero
    for row, length in enumerate(input_lengths):
        input_values[row, length:] = 0.0

    loss = model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = UniSpeechSatForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -300,7 +338,14 @@ class UniSpeechSatModelTester:
|
||||
@require_torch
|
||||
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
||||
(
|
||||
UniSpeechSatForCTC,
|
||||
UniSpeechSatForPreTraining,
|
||||
UniSpeechSatModel,
|
||||
UniSpeechSatForSequenceClassification,
|
||||
UniSpeechSatForAudioFrameClassification,
|
||||
UniSpeechSatForXVector,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -335,6 +380,10 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_xvector_train(self):
    """Feed a freshly prepared config and inputs to the tester's x-vector training check."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_xvector_training(*prepared)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
@@ -417,6 +466,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -623,6 +673,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"label_embeddings_concat",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -811,3 +862,56 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
|
||||
|
||||
def test_inference_diarization(self):
    """Check diarization logits of the pretrained UniSpeechSat checkpoint against stored values."""
    checkpoint = "anton-l/unispeech-sat-base-plus-sd"
    model = UniSpeechSatForAudioFrameClassification.from_pretrained(checkpoint)
    model = model.to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

    input_data = self._load_superb("sd", 4)
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    input_values = inputs.input_values.to(torch_device)
    attention_mask = inputs.attention_mask.to(torch_device)

    with torch.no_grad():
        outputs = model(input_values, attention_mask=attention_mask)

    # Threshold the logits at zero to get 0/1 labels. The original note gives the
    # per-item layout as (num_frames, num_speakers); the leading axis indexed with 0
    # below is presumably the batch.
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.6119, -5.5845], [-3.7772, -5.4824], [-3.6914, -5.1619], [-4.7560, -5.0496]],
            [[-6.3785, -4.8365], [-5.5863, -5.4149], [-5.5639, -4.8469], [-6.1511, -4.0052]],
            [[-6.0355, -3.7414], [-5.5968, -4.8061], [-5.4620, -4.7310], [-5.5864, -4.6078]],
            [[-5.9493, -4.8963], [-4.4050, -5.4476], [-4.1755, -5.1395], [-4.0272, -4.3705]],
        ],
        device=torch_device,
    )

    first_item = labels[0]
    self.assertEqual(first_item[:, 0].sum(), 270)
    self.assertEqual(first_item[:, 1].sum(), 647)

    # only the first four positions along axis 1 are compared with the reference
    leading_logits = outputs.logits[:, :4]
    self.assertTrue(torch.allclose(leading_logits, expected_logits, atol=1e-3))
|
||||
|
||||
def test_inference_speaker_verification(self):
    """Check x-vector embeddings and loss of the pretrained UniSpeechSat checkpoint.

    Loads the speaker-verification checkpoint, embeds four utterances, and
    compares pairwise cosine similarities and the loss with stored reference
    values (three decimal places).
    """
    model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
    input_data = self._load_superb("si", 4)

    # Pass the sampling rate explicitly, as the other inference tests in this
    # change do, so the extractor can validate it instead of assuming it.
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    # One label per utterance; a 1-D tensor needs no transpose.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9671, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.4941, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.5616, 3)

    self.assertAlmostEqual(outputs.loss.item(), 18.5925, 3)
|
||||
|
||||
@@ -44,10 +44,12 @@ if is_torch_available():
|
||||
|
||||
from transformers import (
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2Processor,
|
||||
)
|
||||
@@ -96,6 +98,10 @@ class Wav2Vec2ModelTester:
|
||||
do_stable_layer_norm=False,
|
||||
num_adapter_layers=1,
|
||||
adapter_stride=2,
|
||||
tdnn_dim=(32, 32),
|
||||
tdnn_kernel=(5, 3),
|
||||
tdnn_dilation=(1, 2),
|
||||
xvector_output_dim=32,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -126,6 +132,10 @@ class Wav2Vec2ModelTester:
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.scope = scope
|
||||
self.tdnn_dim = tdnn_dim
|
||||
self.tdnn_kernel = tdnn_kernel
|
||||
self.tdnn_dilation = tdnn_dilation
|
||||
self.xvector_output_dim = xvector_output_dim
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
||||
@@ -168,6 +178,10 @@ class Wav2Vec2ModelTester:
|
||||
vocab_size=self.vocab_size,
|
||||
num_adapter_layers=self.num_adapter_layers,
|
||||
adapter_stride=self.adapter_stride,
|
||||
tdnn_dim=self.tdnn_dim,
|
||||
tdnn_kernel=self.tdnn_kernel,
|
||||
tdnn_dilation=self.tdnn_dilation,
|
||||
xvector_output_dim=self.xvector_output_dim,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
@@ -332,6 +346,29 @@ class Wav2Vec2ModelTester:
|
||||
|
||||
loss.backward()
|
||||
|
||||
def check_xvector_training(self, config, input_values, *args):
    """Run one training step of ``Wav2Vec2ForXVector`` on zero-padded input.

    Builds the model from ``config``, freezes its base model, keeps the first
    three rows of ``input_values`` and zeroes their tails, then asserts that
    the resulting loss is not infinite and back-propagates it. Extra
    positional arguments are ignored.
    """
    config.ctc_zero_infinity = True
    model = Wav2Vec2ForXVector(config=config)
    model.to(torch_device)
    model.train()

    # freeze the base model; per the original note only the classification head keeps training
    model.freeze_base_model()

    # three rows, one for each padded length below
    input_values = input_values[:3]

    total_length = input_values.shape[-1]
    input_lengths = [total_length // divisor for divisor in (4, 2, 1)]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # pad input: everything past each row's nominal length becomes zero
    for row, length in enumerate(input_lengths):
        input_values[row, length:] = 0.0

    loss = model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||
|
||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||
model = Wav2Vec2ForCTC(config)
|
||||
model.to(torch_device)
|
||||
@@ -398,6 +435,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_xvector_train(self):
    """Feed a freshly prepared config and inputs to the tester's x-vector training check."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_xvector_training(*prepared)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
@@ -489,6 +530,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -573,7 +615,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
||||
(
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
Wav2Vec2ForSequenceClassification,
|
||||
Wav2Vec2ForPreTraining,
|
||||
Wav2Vec2ForAudioFrameClassification,
|
||||
Wav2Vec2ForXVector,
|
||||
)
|
||||
if is_torch_available()
|
||||
else ()
|
||||
)
|
||||
@@ -622,6 +672,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||
|
||||
def test_xvector_train(self):
    """Feed a freshly prepared config and inputs to the tester's x-vector training check."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_xvector_training(*prepared)
|
||||
|
||||
def test_labels_out_of_vocab(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||
@@ -703,6 +757,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"project_q.bias",
|
||||
"feature_projection.projection.weight",
|
||||
"feature_projection.projection.bias",
|
||||
"objective.weight",
|
||||
]
|
||||
if param.requires_grad:
|
||||
if any([x in name for x in uniform_init_parms]):
|
||||
@@ -1369,3 +1424,54 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||
transcription = processor.batch_decode(logits.cpu().numpy()).text
|
||||
|
||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||
|
||||
def test_inference_diarization(self):
    """Check diarization logits of the pretrained Wav2Vec2 checkpoint against stored values."""
    checkpoint = "anton-l/wav2vec2-base-superb-sd"
    model = Wav2Vec2ForAudioFrameClassification.from_pretrained(checkpoint)
    model = model.to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)

    input_data = self._load_superb("sd", 4)
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    input_values = inputs.input_values.to(torch_device)
    attention_mask = inputs.attention_mask.to(torch_device)

    with torch.no_grad():
        outputs = model(input_values, attention_mask=attention_mask)

    # Threshold the logits at zero to get 0/1 labels. The original note gives the
    # per-item layout as (num_frames, num_speakers); the leading axis indexed with 0
    # below is presumably the batch.
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.2807, -5.1272], [-5.4059, -4.7757], [-5.2764, -4.9621], [-5.0117, -4.5851]],
            [[-1.7643, -0.5462], [-1.7369, -0.2649], [-1.5066, -0.6200], [-4.5703, -2.4863]],
            [[-0.8656, -0.4783], [-0.8899, -0.3289], [-0.9267, -0.5781], [-0.7817, -0.4619]],
            [[-4.8625, -2.5316], [-5.2339, -2.2155], [-4.9835, -2.0344], [-4.4727, -1.8421]],
        ],
        device=torch_device,
    )

    first_item = labels[0]
    self.assertEqual(first_item[:, 0].sum(), 555)
    self.assertEqual(first_item[:, 1].sum(), 299)

    # only the first four positions along axis 1 are compared with the reference
    leading_logits = outputs.logits[:, :4]
    self.assertTrue(torch.allclose(leading_logits, expected_logits, atol=1e-3))
|
||||
|
||||
def test_inference_speaker_verification(self):
    """Check x-vector embeddings and loss of the pretrained Wav2Vec2 checkpoint.

    Loads the speaker-verification checkpoint, embeds four utterances, and
    compares pairwise cosine similarities and the loss with stored reference
    values (three decimal places).
    """
    model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sv")
    input_data = self._load_superb("si", 4)

    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    # One label per utterance; a 1-D tensor needs no transpose.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1).cpu()

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # Compare plain Python floats (``.item()``), matching the UniSpeechSat
    # counterpart of this test, rather than handing 0-d arrays to assertAlmostEqual.
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9758, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.7579, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.7594, 3)

    self.assertAlmostEqual(outputs.loss.item(), 17.7963, 3)
|
||||
|
||||
@@ -74,6 +74,12 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
|
||||
"AutoModelForNextSentencePrediction",
|
||||
),
|
||||
(
|
||||
"audio-frame-classification",
|
||||
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModelForAudioFrameClassification",
|
||||
),
|
||||
("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user