Add Speaker Diarization and Verification heads (#14723)
* Models * Squashed commit of the following: commit 72278e1e931a16d0879acc77f65762f3364833d0 Author: anton-l <aglozhkov@gmail.com> Date: Fri Dec 10 21:45:08 2021 +0300 * Add unispeech heads * Add sd/sv automodels * Docs cleanup * Fix docstrings * rename xvector classes * examples * Tests cleanup * Style * Better checkpoints for tests * leftover docs * apply review suggestions * Style + init tests * Update unispeech-sat tdnn downsampling
This commit is contained in:
@@ -181,6 +181,13 @@ AutoModelForAudioClassification
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForAudioFrameClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.AutoModelForAudioFrameClassification
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
AutoModelForCTC
|
AutoModelForCTC
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@@ -195,6 +202,13 @@ AutoModelForSpeechSeq2Seq
|
|||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForAudioXVector
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.AutoModelForAudioXVector
|
||||||
|
:members:
|
||||||
|
|
||||||
|
|
||||||
AutoModelForObjectDetection
|
AutoModelForObjectDetection
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -85,6 +85,20 @@ UniSpeechSatForSequenceClassification
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
UniSpeechSatForAudioFrameClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.UniSpeechSatForAudioFrameClassification
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
UniSpeechSatForXVector
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.UniSpeechSatForXVector
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
UniSpeechSatForPreTraining
|
UniSpeechSatForPreTraining
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,20 @@ Wav2Vec2ForSequenceClassification
|
|||||||
:members: forward
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
Wav2Vec2ForAudioFrameClassification
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.Wav2Vec2ForAudioFrameClassification
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
|
Wav2Vec2ForXVector
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
.. autoclass:: transformers.Wav2Vec2ForXVector
|
||||||
|
:members: forward
|
||||||
|
|
||||||
|
|
||||||
Wav2Vec2ForPreTraining
|
Wav2Vec2ForPreTraining
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|||||||
@@ -649,6 +649,8 @@ if is_torch_available():
|
|||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
"AutoModelForAudioClassification",
|
"AutoModelForAudioClassification",
|
||||||
|
"AutoModelForAudioFrameClassification",
|
||||||
|
"AutoModelForAudioXVector",
|
||||||
"AutoModelForCausalLM",
|
"AutoModelForCausalLM",
|
||||||
"AutoModelForCTC",
|
"AutoModelForCTC",
|
||||||
"AutoModelForImageClassification",
|
"AutoModelForImageClassification",
|
||||||
@@ -1325,9 +1327,11 @@ if is_torch_available():
|
|||||||
_import_structure["models.unispeech_sat"].extend(
|
_import_structure["models.unispeech_sat"].extend(
|
||||||
[
|
[
|
||||||
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"UniSpeechSatForAudioFrameClassification",
|
||||||
"UniSpeechSatForCTC",
|
"UniSpeechSatForCTC",
|
||||||
"UniSpeechSatForPreTraining",
|
"UniSpeechSatForPreTraining",
|
||||||
"UniSpeechSatForSequenceClassification",
|
"UniSpeechSatForSequenceClassification",
|
||||||
|
"UniSpeechSatForXVector",
|
||||||
"UniSpeechSatModel",
|
"UniSpeechSatModel",
|
||||||
"UniSpeechSatPreTrainedModel",
|
"UniSpeechSatPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -1358,10 +1362,12 @@ if is_torch_available():
|
|||||||
_import_structure["models.wav2vec2"].extend(
|
_import_structure["models.wav2vec2"].extend(
|
||||||
[
|
[
|
||||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"Wav2Vec2ForAudioFrameClassification",
|
||||||
"Wav2Vec2ForCTC",
|
"Wav2Vec2ForCTC",
|
||||||
"Wav2Vec2ForMaskedLM",
|
"Wav2Vec2ForMaskedLM",
|
||||||
"Wav2Vec2ForPreTraining",
|
"Wav2Vec2ForPreTraining",
|
||||||
"Wav2Vec2ForSequenceClassification",
|
"Wav2Vec2ForSequenceClassification",
|
||||||
|
"Wav2Vec2ForXVector",
|
||||||
"Wav2Vec2Model",
|
"Wav2Vec2Model",
|
||||||
"Wav2Vec2PreTrainedModel",
|
"Wav2Vec2PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -2603,6 +2609,8 @@ if TYPE_CHECKING:
|
|||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForAudioClassification,
|
AutoModelForAudioClassification,
|
||||||
|
AutoModelForAudioFrameClassification,
|
||||||
|
AutoModelForAudioXVector,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForCTC,
|
AutoModelForCTC,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
@@ -3164,9 +3172,11 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.unispeech_sat import (
|
from .models.unispeech_sat import (
|
||||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
UniSpeechSatForAudioFrameClassification,
|
||||||
UniSpeechSatForCTC,
|
UniSpeechSatForCTC,
|
||||||
UniSpeechSatForPreTraining,
|
UniSpeechSatForPreTraining,
|
||||||
UniSpeechSatForSequenceClassification,
|
UniSpeechSatForSequenceClassification,
|
||||||
|
UniSpeechSatForXVector,
|
||||||
UniSpeechSatModel,
|
UniSpeechSatModel,
|
||||||
UniSpeechSatPreTrainedModel,
|
UniSpeechSatPreTrainedModel,
|
||||||
)
|
)
|
||||||
@@ -3191,10 +3201,12 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from .models.wav2vec2 import (
|
from .models.wav2vec2 import (
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
Wav2Vec2ForAudioFrameClassification,
|
||||||
Wav2Vec2ForCTC,
|
Wav2Vec2ForCTC,
|
||||||
Wav2Vec2ForMaskedLM,
|
Wav2Vec2ForMaskedLM,
|
||||||
Wav2Vec2ForPreTraining,
|
Wav2Vec2ForPreTraining,
|
||||||
Wav2Vec2ForSequenceClassification,
|
Wav2Vec2ForSequenceClassification,
|
||||||
|
Wav2Vec2ForXVector,
|
||||||
Wav2Vec2Model,
|
Wav2Vec2Model,
|
||||||
Wav2Vec2PreTrainedModel,
|
Wav2Vec2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1117,6 +1117,54 @@ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {processor_class}, {model_class}
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
|
>>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> # audio file is decoded on the fly
|
||||||
|
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
|
||||||
|
>>> logits = model(**inputs).logits
|
||||||
|
>>> probabilities = torch.sigmoid(logits[0])
|
||||||
|
>>> # labels is a multi-hot array of shape (num_frames, num_speakers)
|
||||||
|
>>> labels = (probabilities > 0.5).long()
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
PT_SPEECH_XVECTOR_SAMPLE = r"""
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import {processor_class}, {model_class}
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
|
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
||||||
|
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
||||||
|
|
||||||
|
>>> feature_extractor = {processor_class}.from_pretrained('{checkpoint}')
|
||||||
|
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
||||||
|
|
||||||
|
>>> # audio file is decoded on the fly
|
||||||
|
>>> inputs = feature_extractor([d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True)
|
||||||
|
>>> embeddings = model(**inputs).embeddings
|
||||||
|
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
|
||||||
|
|
||||||
|
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
|
||||||
|
>>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
|
||||||
|
>>> similarity = cosine_sim(embeddings[0], embeddings[1])
|
||||||
|
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
|
||||||
|
>>> if similarity < threshold:
|
||||||
|
... print("Speakers are not the same!")
|
||||||
|
"""
|
||||||
|
|
||||||
PT_SAMPLE_DOCSTRINGS = {
|
PT_SAMPLE_DOCSTRINGS = {
|
||||||
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
|
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
|
||||||
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
|
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
|
||||||
@@ -1128,6 +1176,8 @@ PT_SAMPLE_DOCSTRINGS = {
|
|||||||
"SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
|
"SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
|
||||||
"CTC": PT_SPEECH_CTC_SAMPLE,
|
"CTC": PT_SPEECH_CTC_SAMPLE,
|
||||||
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
|
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
|
||||||
|
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
|
||||||
|
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1419,6 +1469,10 @@ def add_code_sample_docstrings(
|
|||||||
code_sample = sample_docstrings["LMHead"]
|
code_sample = sample_docstrings["LMHead"]
|
||||||
elif "CTC" in model_class:
|
elif "CTC" in model_class:
|
||||||
code_sample = sample_docstrings["CTC"]
|
code_sample = sample_docstrings["CTC"]
|
||||||
|
elif "AudioFrameClassification" in model_class:
|
||||||
|
code_sample = sample_docstrings["AudioFrameClassification"]
|
||||||
|
elif "XVector" in model_class and modality == "audio":
|
||||||
|
code_sample = sample_docstrings["AudioXVector"]
|
||||||
elif "Model" in model_class and modality == "audio":
|
elif "Model" in model_class and modality == "audio":
|
||||||
code_sample = sample_docstrings["SpeechBaseModel"]
|
code_sample = sample_docstrings["SpeechBaseModel"]
|
||||||
elif "Model" in model_class or "Encoder" in model_class:
|
elif "Model" in model_class or "Encoder" in model_class:
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ if is_torch_available():
|
|||||||
"MODEL_WITH_LM_HEAD_MAPPING",
|
"MODEL_WITH_LM_HEAD_MAPPING",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
"AutoModelForAudioClassification",
|
"AutoModelForAudioClassification",
|
||||||
|
"AutoModelForAudioFrameClassification",
|
||||||
|
"AutoModelForAudioXVector",
|
||||||
"AutoModelForCausalLM",
|
"AutoModelForCausalLM",
|
||||||
"AutoModelForCTC",
|
"AutoModelForCTC",
|
||||||
"AutoModelForImageClassification",
|
"AutoModelForImageClassification",
|
||||||
@@ -161,6 +163,8 @@ if TYPE_CHECKING:
|
|||||||
MODEL_WITH_LM_HEAD_MAPPING,
|
MODEL_WITH_LM_HEAD_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForAudioClassification,
|
AutoModelForAudioClassification,
|
||||||
|
AutoModelForAudioFrameClassification,
|
||||||
|
AutoModelForAudioXVector,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoModelForCTC,
|
AutoModelForCTC,
|
||||||
AutoModelForImageClassification,
|
AutoModelForImageClassification,
|
||||||
|
|||||||
@@ -538,6 +538,22 @@ MODEL_FOR_CTC_MAPPING_NAMES = OrderedDict(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Audio Frame Classification mapping
|
||||||
|
("wav2vec2", "Wav2Vec2ForAudioFrameClassification"),
|
||||||
|
("unispeech-sat", "UniSpeechSatForAudioFrameClassification"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES = OrderedDict(
|
||||||
|
[
|
||||||
|
# Model for Audio XVector mapping
|
||||||
|
("wav2vec2", "Wav2Vec2ForXVector"),
|
||||||
|
("unispeech-sat", "UniSpeechSatForXVector"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
|
||||||
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_PRETRAINING_MAPPING_NAMES)
|
||||||
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_WITH_LM_HEAD_MAPPING_NAMES)
|
||||||
@@ -578,6 +594,10 @@ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
|||||||
)
|
)
|
||||||
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
|
MODEL_FOR_CTC_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES)
|
||||||
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
|
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)
|
||||||
|
MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||||
|
CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES
|
||||||
|
)
|
||||||
|
MODEL_FOR_AUDIO_XVECTOR_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES)
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(_BaseAutoModelClass):
|
class AutoModel(_BaseAutoModelClass):
|
||||||
@@ -726,6 +746,22 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForAudioFrameClassification(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForAudioFrameClassification = auto_class_update(
|
||||||
|
AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForAudioXVector(_BaseAutoModelClass):
|
||||||
|
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector")
|
||||||
|
|
||||||
|
|
||||||
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
class AutoModelWithLMHead(_AutoModelWithLMHead):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config):
|
||||||
|
|||||||
@@ -42,9 +42,9 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "HubertConfig"
|
_CONFIG_FOR_DOC = "HubertConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
|
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
|
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
|
||||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 1
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
@@ -1182,7 +1182,7 @@ class HubertForSequenceClassification(HubertPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "SEWConfig"
|
_CONFIG_FOR_DOC = "SEWConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
|
_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
|
||||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 1
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
@@ -1067,7 +1067,7 @@ class SEWForSequenceClassification(SEWPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
|||||||
@@ -39,9 +39,9 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "SEWDConfig"
|
_CONFIG_FOR_DOC = "SEWDConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
|
_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
|
||||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 1
|
_HIDDEN_STATES_START_POSITION = 1
|
||||||
|
|
||||||
@@ -1598,7 +1598,7 @@ class SEWDForSequenceClassification(SEWDPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
|||||||
@@ -44,9 +44,9 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "UniSpeechConfig"
|
_CONFIG_FOR_DOC = "UniSpeechConfig"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
|
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-large-1500h-cv"
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv"
|
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-large-1500h-cv"
|
||||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
@@ -1481,7 +1481,7 @@ class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
|||||||
@@ -27,9 +27,11 @@ _import_structure = {
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_unispeech_sat"] = [
|
_import_structure["modeling_unispeech_sat"] = [
|
||||||
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"UniSpeechSatForAudioFrameClassification",
|
||||||
"UniSpeechSatForCTC",
|
"UniSpeechSatForCTC",
|
||||||
"UniSpeechSatForPreTraining",
|
"UniSpeechSatForPreTraining",
|
||||||
"UniSpeechSatForSequenceClassification",
|
"UniSpeechSatForSequenceClassification",
|
||||||
|
"UniSpeechSatForXVector",
|
||||||
"UniSpeechSatModel",
|
"UniSpeechSatModel",
|
||||||
"UniSpeechSatPreTrainedModel",
|
"UniSpeechSatPreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -40,9 +42,11 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_unispeech_sat import (
|
from .modeling_unispeech_sat import (
|
||||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
UniSpeechSatForAudioFrameClassification,
|
||||||
UniSpeechSatForCTC,
|
UniSpeechSatForCTC,
|
||||||
UniSpeechSatForPreTraining,
|
UniSpeechSatForPreTraining,
|
||||||
UniSpeechSatForSequenceClassification,
|
UniSpeechSatForSequenceClassification,
|
||||||
|
UniSpeechSatForXVector,
|
||||||
UniSpeechSatModel,
|
UniSpeechSatModel,
|
||||||
UniSpeechSatPreTrainedModel,
|
UniSpeechSatPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -153,6 +153,17 @@ class UniSpeechSatConfig(PretrainedConfig):
|
|||||||
instance of :class:`~transformers.UniSpeechSatForSequenceClassification`.
|
instance of :class:`~transformers.UniSpeechSatForSequenceClassification`.
|
||||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||||
Dimensionality of the projection before token mean-pooling for classification.
|
Dimensionality of the projection before token mean-pooling for classification.
|
||||||
|
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
|
||||||
|
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
|
||||||
|
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
|
||||||
|
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
|
||||||
|
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
|
||||||
|
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
|
||||||
|
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
|
||||||
|
A tuple of integers defining the dilation factor of each 1D convolutional layer in `TDNN` module of the
|
||||||
|
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
|
||||||
|
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
|
||||||
|
Dimensionality of the `XVector` embedding vectors.
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -213,6 +224,10 @@ class UniSpeechSatConfig(PretrainedConfig):
|
|||||||
ctc_zero_infinity=False,
|
ctc_zero_infinity=False,
|
||||||
use_weighted_layer_sum=False,
|
use_weighted_layer_sum=False,
|
||||||
classifier_proj_size=256,
|
classifier_proj_size=256,
|
||||||
|
tdnn_dim=(512, 512, 512, 512, 1500),
|
||||||
|
tdnn_kernel=(5, 3, 3, 1, 1),
|
||||||
|
tdnn_dilation=(1, 2, 3, 1, 1),
|
||||||
|
xvector_output_dim=512,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
@@ -246,7 +261,6 @@ class UniSpeechSatConfig(PretrainedConfig):
|
|||||||
self.num_clusters = num_clusters
|
self.num_clusters = num_clusters
|
||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
self.do_stable_layer_norm = do_stable_layer_norm
|
||||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
self.classifier_proj_size = classifier_proj_size
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||||
@@ -282,3 +296,12 @@ class UniSpeechSatConfig(PretrainedConfig):
|
|||||||
# ctc loss
|
# ctc loss
|
||||||
self.ctc_loss_reduction = ctc_loss_reduction
|
self.ctc_loss_reduction = ctc_loss_reduction
|
||||||
self.ctc_zero_infinity = ctc_zero_infinity
|
self.ctc_zero_infinity = ctc_zero_infinity
|
||||||
|
|
||||||
|
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
|
||||||
|
self.classifier_proj_size = classifier_proj_size
|
||||||
|
|
||||||
|
# XVector-specific parameters. Feel free to ignore for other classes.
|
||||||
|
self.tdnn_dim = list(tdnn_dim)
|
||||||
|
self.tdnn_kernel = list(tdnn_kernel)
|
||||||
|
self.tdnn_dilation = list(tdnn_dilation)
|
||||||
|
self.xvector_output_dim = xvector_output_dim
|
||||||
|
|||||||
@@ -0,0 +1,110 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 The HuggingFace Inc. team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Convert UniSpeechSat S3PRL checkpoint."""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
UniSpeechSatConfig,
|
||||||
|
UniSpeechSatForAudioFrameClassification,
|
||||||
|
UniSpeechSatForSequenceClassification,
|
||||||
|
UniSpeechSatForXVector,
|
||||||
|
Wav2Vec2FeatureExtractor,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_classification(base_model_name, hf_config, downstream_dict):
|
||||||
|
model = UniSpeechSatForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
|
||||||
|
model.projector.weight.data = downstream_dict["projector.weight"]
|
||||||
|
model.projector.bias.data = downstream_dict["projector.bias"]
|
||||||
|
model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
|
||||||
|
model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_diarization(base_model_name, hf_config, downstream_dict):
|
||||||
|
model = UniSpeechSatForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)
|
||||||
|
model.classifier.weight.data = downstream_dict["model.linear.weight"]
|
||||||
|
model.classifier.bias.data = downstream_dict["model.linear.bias"]
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_xvector(base_model_name, hf_config, downstream_dict):
|
||||||
|
model = UniSpeechSatForXVector.from_pretrained(base_model_name, config=hf_config)
|
||||||
|
model.projector.weight.data = downstream_dict["connector.weight"]
|
||||||
|
model.projector.bias.data = downstream_dict["connector.bias"]
|
||||||
|
for i, kernel_size in enumerate(hf_config.tdnn_kernel):
|
||||||
|
model.tdnn[i].kernel.weight.data = downstream_dict[
|
||||||
|
f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
|
||||||
|
]
|
||||||
|
model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]
|
||||||
|
|
||||||
|
model.feature_extractor.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"]
|
||||||
|
model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"]
|
||||||
|
model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"]
|
||||||
|
model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"]
|
||||||
|
model.objective.weight.data = downstream_dict["objective.W"]
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
|
||||||
|
"""
|
||||||
|
Copy/paste/tweak model's weights to transformers design.
|
||||||
|
"""
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||||
|
|
||||||
|
downstream_dict = checkpoint["Downstream"]
|
||||||
|
|
||||||
|
hf_config = UniSpeechSatConfig.from_pretrained(config_path)
|
||||||
|
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||||
|
base_model_name, return_attention_mask=True, do_normalize=False
|
||||||
|
)
|
||||||
|
|
||||||
|
arch = hf_config.architectures[0]
|
||||||
|
if arch.endswith("ForSequenceClassification"):
|
||||||
|
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
|
||||||
|
elif arch.endswith("ForAudioFrameClassification"):
|
||||||
|
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
|
||||||
|
elif arch.endswith("ForXVector"):
|
||||||
|
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
|
||||||
|
|
||||||
|
if hf_config.use_weighted_layer_sum:
|
||||||
|
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
||||||
|
|
||||||
|
hf_feature_extractor.save_pretrained(model_dump_path)
|
||||||
|
hf_model.save_pretrained(model_dump_path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
|
||||||
|
)
|
||||||
|
parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
|
||||||
|
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
|
||||||
|
parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
|
||||||
@@ -33,7 +33,7 @@ from ...file_utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
|
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_unispeech_sat import UniSpeechSatConfig
|
from .configuration_unispeech_sat import UniSpeechSatConfig
|
||||||
@@ -45,9 +45,11 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
|
_CONFIG_FOR_DOC = "UniSpeechSatConfig"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"
|
_CHECKPOINT_FOR_DOC = "microsoft/unispeech-sat-base-plus"
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus"
|
_SEQ_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus"
|
||||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
_FRAME_CLASS_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sd"
|
||||||
|
_XVECTOR_CHECKPOINT = "microsoft/unispeech-sat-base-plus-sv"
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
@@ -123,6 +125,38 @@ class UniSpeechSatForPreTrainingOutput(ModelOutput):
|
|||||||
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class XVectorOutput(ModelOutput):
    """
    Output type of :class:`~transformers.UniSpeechSatForXVector`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Classification hidden states before AMSoftmax.
        embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Utterance embeddings used for vector similarity-based retrieval.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Field order is part of the interface: it fixes the tuple layout of the output.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
|
||||||
def _compute_mask_indices(
|
def _compute_mask_indices(
|
||||||
shape: Tuple[int, int],
|
shape: Tuple[int, int],
|
||||||
@@ -1472,7 +1506,7 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
@@ -1538,3 +1572,285 @@ class UniSpeechSatForSequenceClassification(UniSpeechSatPreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    UniSpeech-SAT Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForAudioFrameClassification(UniSpeechSatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.unispeech_sat = UniSpeechSatModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable scalar per layer, initialised uniformly; softmax-normalised in forward().
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # Linear head applied to the (possibly layer-averaged) hidden states.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.unispeech_sat.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.unispeech_sat.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_FRAME_CLASS_CHECKPOINT,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        This head accepts no :obj:`labels` argument and computes no loss; the :obj:`loss` field of the returned
        :class:`~transformers.modeling_outputs.TokenClassifierOutput` is always :obj:`None`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden state, so it overrides the caller's choice.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.unispeech_sat(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer hidden states along a new dim 1 and collapse that dim with
            # softmax-normalised layer weights.
            # assumes each hidden state is (batch, seq, hidden) - TODO confirm against the base model
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            # Without layer averaging, only the first element of the base model output is used.
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        if not return_dict:
            # Tuple form: logits followed by whatever the base model returned from the hidden-states slot on.
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return output

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
|
||||||
|
class AMSoftmaxLoss(nn.Module):
    """
    Additive-margin softmax loss.

    Inputs and class weights are L2-normalised so their product is a cosine similarity; the margin is
    subtracted from the target-class cosine only, and the scaled result is fed to cross-entropy.

    Args:
        input_dim: size of the incoming feature vectors.
        num_labels: number of target classes.
        scale: multiplier applied to the margin-adjusted cosines.
        margin: value subtracted from the cosine of the target class.
    """

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale, self.margin, self.num_labels = scale, margin, num_labels
        # One column per class; columns are normalised on the fly in forward().
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        """Return the scalar loss for `hidden_states` (2D, one row per example) and integer `labels`."""
        targets = labels.flatten()
        unit_weight = nn.functional.normalize(self.weight, dim=0)
        unit_inputs = nn.functional.normalize(hidden_states, dim=1)
        cosine = torch.mm(unit_inputs, unit_weight)

        # Only the entry of the true class receives the margin penalty.
        is_target = nn.functional.one_hot(targets, self.num_labels).bool()
        scaled_logits = self.scale * torch.where(is_target, cosine - self.margin, cosine)
        return self.loss(scaled_logits, targets)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
|
||||||
|
class TDNNLayer(nn.Module):
    """
    One time-delay layer: each output frame is a linear map (followed by ReLU) of `kernel_size` input
    frames taken `dilation` steps apart.

    Args:
        config: object exposing `tdnn_dim`, `tdnn_kernel` and `tdnn_dilation` sequences.
        layer_id: index of this layer in those sequences. The input width is the previous layer's
            `tdnn_dim` entry, or the layer's own entry for the first layer.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        widths = config.tdnn_dim
        self.in_conv_dim = widths[layer_id] if layer_id <= 0 else widths[layer_id - 1]
        self.out_conv_dim = widths[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        # The temporal context window is flattened, so the "convolution" is a plain linear layer.
        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        """Map (batch, time, in_conv_dim) features to (batch, time - dilation * (kernel_size - 1), out_conv_dim)."""
        # Treat the sequence as a one-channel image and slide a window spanning the full feature
        # width and `kernel_size` dilated time steps.
        window = (self.kernel_size, self.in_conv_dim)
        context = nn.functional.unfold(
            hidden_states.unsqueeze(1),
            window,
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        # unfold yields (batch, window_size, frames); the linear layer wants frames before features.
        projected = self.kernel(context.transpose(1, 2))
        return self.activation(projected)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    UniSpeech-SAT Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    UNISPEECH_SAT_START_DOCSTRING,
)
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
class UniSpeechSatForXVector(UniSpeechSatPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.unispeech_sat = UniSpeechSatModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable scalar per layer, initialised uniformly; softmax-normalised in forward().
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # Projects encoder features to the input width of the first TDNN layer.
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Input is twice the last TDNN width because forward() concatenates mean and std statistics.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.unispeech_sat.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.unispeech_sat.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers
        """
        # NOTE(review): only `config.tdnn_kernel` is consulted here, while TDNNLayer.forward unfolds
        # with `config.tdnn_dilation`; for any layer with dilation > 1 the length computed below is
        # larger than the number of frames that layer actually emits. Confirm whether this is intended
        # before relying on these lengths for masking.

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        # Every TDNN layer is applied with stride 1.
        for kernel_size in self.config.tdnn_kernel:
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)

        return input_lengths

    @add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Class indices in :obj:`[0, ..., config.num_labels - 1]` used to compute the additive-margin softmax
            (AM-Softmax) classification loss. No loss is computed when :obj:`labels` is not given.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden state, so it overrides the caller's choice.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.unispeech_sat(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer hidden states along a new dim 1 and collapse that dim with
            # softmax-normalised layer weights.
            # assumes each hidden state is (batch, seq, hidden) - TODO confirm against the base model
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            # No mask: pool over every frame of every example.
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # With a mask, pool each example only over its first `length` frames.
            # `_get_feat_extract_output_lengths` is not defined in this class; presumably inherited
            # from UniSpeechSatPreTrainedModel - verify.
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            # Tuple form: (loss,) only when labels were given, then logits, embeddings and the base
            # model outputs from the hidden-states slot on.
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -31,10 +31,12 @@ _import_structure = {
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_wav2vec2"] = [
|
_import_structure["modeling_wav2vec2"] = [
|
||||||
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"Wav2Vec2ForAudioFrameClassification",
|
||||||
"Wav2Vec2ForCTC",
|
"Wav2Vec2ForCTC",
|
||||||
"Wav2Vec2ForMaskedLM",
|
"Wav2Vec2ForMaskedLM",
|
||||||
"Wav2Vec2ForPreTraining",
|
"Wav2Vec2ForPreTraining",
|
||||||
"Wav2Vec2ForSequenceClassification",
|
"Wav2Vec2ForSequenceClassification",
|
||||||
|
"Wav2Vec2ForXVector",
|
||||||
"Wav2Vec2Model",
|
"Wav2Vec2Model",
|
||||||
"Wav2Vec2PreTrainedModel",
|
"Wav2Vec2PreTrainedModel",
|
||||||
]
|
]
|
||||||
@@ -65,10 +67,12 @@ if TYPE_CHECKING:
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_wav2vec2 import (
|
from .modeling_wav2vec2 import (
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
Wav2Vec2ForAudioFrameClassification,
|
||||||
Wav2Vec2ForCTC,
|
Wav2Vec2ForCTC,
|
||||||
Wav2Vec2ForMaskedLM,
|
Wav2Vec2ForMaskedLM,
|
||||||
Wav2Vec2ForPreTraining,
|
Wav2Vec2ForPreTraining,
|
||||||
Wav2Vec2ForSequenceClassification,
|
Wav2Vec2ForSequenceClassification,
|
||||||
|
Wav2Vec2ForXVector,
|
||||||
Wav2Vec2Model,
|
Wav2Vec2Model,
|
||||||
Wav2Vec2PreTrainedModel,
|
Wav2Vec2PreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -80,10 +80,10 @@ class Wav2Vec2Config(PretrainedConfig):
|
|||||||
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
|
feature extractor. The length of `conv_dim` defines the number of 1D convolutional layers.
|
||||||
conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
|
conv_stride (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 2, 2, 2, 2, 2, 2)`):
|
||||||
A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
|
A tuple of integers defining the stride of each 1D convolutional layer in the feature extractor. The length
|
||||||
of `conv_stride` defines the number of convolutional layers and has to match the the length of `conv_dim`.
|
of `conv_stride` defines the number of convolutional layers and has to match the length of `conv_dim`.
|
||||||
conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`):
|
conv_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(10, 3, 3, 3, 3, 3, 3)`):
|
||||||
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
|
A tuple of integers defining the kernel size of each 1D convolutional layer in the feature extractor. The
|
||||||
length of `conv_kernel` defines the number of convolutional layers and has to match the the length of
|
length of `conv_kernel` defines the number of convolutional layers and has to match the length of
|
||||||
`conv_dim`.
|
`conv_dim`.
|
||||||
conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
conv_bias (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether the 1D convolutional layers have a bias.
|
Whether the 1D convolutional layers have a bias.
|
||||||
@@ -153,6 +153,17 @@ class Wav2Vec2Config(PretrainedConfig):
|
|||||||
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
|
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
|
||||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||||
Dimensionality of the projection before token mean-pooling for classification.
|
Dimensionality of the projection before token mean-pooling for classification.
|
||||||
|
tdnn_dim (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(512, 512, 512, 512, 1500)`):
|
||||||
|
A tuple of integers defining the number of output channels of each 1D convolutional layer in the `TDNN`
|
||||||
|
module of the `XVector` model. The length of `tdnn_dim` defines the number of `TDNN` layers.
|
||||||
|
tdnn_kernel (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 3, 3, 1, 1)`):
|
||||||
|
A tuple of integers defining the kernel size of each 1D convolutional layer in the `TDNN` module of the
|
||||||
|
`XVector` model. The length of `tdnn_kernel` has to match the length of `tdnn_dim`.
|
||||||
|
tdnn_dilation (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(1, 2, 3, 1, 1)`):
|
||||||
|
A tuple of integers defining the dilation factor of each 1D convolutional layer in the `TDNN` module of the
|
||||||
|
`XVector` model. The length of `tdnn_dilation` has to match the length of `tdnn_dim`.
|
||||||
|
xvector_output_dim (:obj:`int`, `optional`, defaults to 512):
|
||||||
|
Dimensionality of the `XVector` embedding vectors.
|
||||||
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
|
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
|
||||||
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
|
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
|
||||||
@@ -226,6 +237,10 @@ class Wav2Vec2Config(PretrainedConfig):
|
|||||||
ctc_zero_infinity=False,
|
ctc_zero_infinity=False,
|
||||||
use_weighted_layer_sum=False,
|
use_weighted_layer_sum=False,
|
||||||
classifier_proj_size=256,
|
classifier_proj_size=256,
|
||||||
|
tdnn_dim=(512, 512, 512, 512, 1500),
|
||||||
|
tdnn_kernel=(5, 3, 3, 1, 1),
|
||||||
|
tdnn_dilation=(1, 2, 3, 1, 1),
|
||||||
|
xvector_output_dim=512,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
bos_token_id=1,
|
bos_token_id=1,
|
||||||
eos_token_id=2,
|
eos_token_id=2,
|
||||||
@@ -262,7 +277,6 @@ class Wav2Vec2Config(PretrainedConfig):
|
|||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
self.do_stable_layer_norm = do_stable_layer_norm
|
||||||
self.use_weighted_layer_sum = use_weighted_layer_sum
|
self.use_weighted_layer_sum = use_weighted_layer_sum
|
||||||
self.classifier_proj_size = classifier_proj_size
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||||
@@ -305,3 +319,12 @@ class Wav2Vec2Config(PretrainedConfig):
|
|||||||
self.adapter_stride = adapter_stride
|
self.adapter_stride = adapter_stride
|
||||||
self.num_adapter_layers = num_adapter_layers
|
self.num_adapter_layers = num_adapter_layers
|
||||||
self.output_hidden_size = output_hidden_size or hidden_size
|
self.output_hidden_size = output_hidden_size or hidden_size
|
||||||
|
|
||||||
|
# SequenceClassification-specific parameter. Feel free to ignore for other classes.
|
||||||
|
self.classifier_proj_size = classifier_proj_size
|
||||||
|
|
||||||
|
# XVector-specific parameters. Feel free to ignore for other classes.
|
||||||
|
self.tdnn_dim = list(tdnn_dim)
|
||||||
|
self.tdnn_kernel = list(tdnn_kernel)
|
||||||
|
self.tdnn_dilation = list(tdnn_dilation)
|
||||||
|
self.xvector_output_dim = xvector_output_dim
|
||||||
|
|||||||
@@ -19,13 +19,52 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import Wav2Vec2Config, Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification, logging
|
from transformers import (
|
||||||
|
Wav2Vec2Config,
|
||||||
|
Wav2Vec2FeatureExtractor,
|
||||||
|
Wav2Vec2ForAudioFrameClassification,
|
||||||
|
Wav2Vec2ForSequenceClassification,
|
||||||
|
Wav2Vec2ForXVector,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
SUPPORTED_MODELS = ["UtteranceLevel"]
|
|
||||||
|
def convert_classification(base_model_name, hf_config, downstream_dict):
    """
    Build a `Wav2Vec2ForSequenceClassification` from `base_model_name` / `hf_config` and overwrite its
    projector and classifier parameters with the matching entries of the s3prl `downstream_dict`.

    Returns the populated model.
    """
    model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config)

    # (target parameter, s3prl key) pairs, copied in the order listed.
    transfers = (
        (model.projector.weight, "projector.weight"),
        (model.projector.bias, "projector.bias"),
        (model.classifier.weight, "model.post_net.linear.weight"),
        (model.classifier.bias, "model.post_net.linear.bias"),
    )
    for parameter, source_key in transfers:
        parameter.data = downstream_dict[source_key]

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_diarization(base_model_name, hf_config, downstream_dict):
    """
    Build a `Wav2Vec2ForAudioFrameClassification` from `base_model_name` / `hf_config` and overwrite its
    classifier parameters with the s3prl linear layer stored in `downstream_dict`.

    Returns the populated model.
    """
    model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config)

    head = model.classifier
    for parameter, source_key in ((head.weight, "model.linear.weight"), (head.bias, "model.linear.bias")):
        parameter.data = downstream_dict[source_key]

    return model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_xvector(base_model_name, hf_config, downstream_dict):
    """
    Build a `Wav2Vec2ForXVector` from `base_model_name` / `hf_config` and overwrite its head parameters
    (projector, TDNN kernels, utterance-level linears and the loss weight) with the matching entries
    of the s3prl `downstream_dict`.

    Returns the populated model.
    """
    model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config)

    model.projector.weight.data = downstream_dict["connector.weight"]
    model.projector.bias.data = downstream_dict["connector.bias"]

    # One s3prl frame-level module per configured TDNN kernel.
    for i in range(len(hf_config.tdnn_kernel)):
        tdnn_kernel = model.tdnn[i].kernel
        tdnn_kernel.weight.data = downstream_dict[
            f"model.framelevel_feature_extractor.module.{i}.kernel.weight"
        ]
        tdnn_kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"]

    # Utterance-level linears, then the AM-Softmax class weights.
    remaining = (
        (model.feature_extractor.weight, "model.utterancelevel_feature_extractor.linear1.weight"),
        (model.feature_extractor.bias, "model.utterancelevel_feature_extractor.linear1.bias"),
        (model.classifier.weight, "model.utterancelevel_feature_extractor.linear2.weight"),
        (model.classifier.bias, "model.utterancelevel_feature_extractor.linear2.bias"),
        (model.objective.weight, "objective.W"),
    )
    for parameter, source_key in remaining:
        parameter.data = downstream_dict[source_key]

    return model
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -34,24 +73,26 @@ def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, mode
|
|||||||
Copy/paste/tweak model's weights to transformers design.
|
Copy/paste/tweak model's weights to transformers design.
|
||||||
"""
|
"""
|
||||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||||
if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
|
|
||||||
raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
|
|
||||||
|
|
||||||
downstream_dict = checkpoint["Downstream"]
|
downstream_dict = checkpoint["Downstream"]
|
||||||
|
|
||||||
hf_congfig = Wav2Vec2Config.from_pretrained(config_path)
|
hf_config = Wav2Vec2Config.from_pretrained(config_path)
|
||||||
hf_model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_congfig)
|
|
||||||
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||||
base_model_name, return_attention_mask=True, do_normalize=False
|
base_model_name, return_attention_mask=True, do_normalize=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if hf_congfig.use_weighted_layer_sum:
|
arch = hf_config.architectures[0]
|
||||||
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
if arch.endswith("ForSequenceClassification"):
|
||||||
|
hf_model = convert_classification(base_model_name, hf_config, downstream_dict)
|
||||||
|
elif arch.endswith("ForAudioFrameClassification"):
|
||||||
|
hf_model = convert_diarization(base_model_name, hf_config, downstream_dict)
|
||||||
|
elif arch.endswith("ForXVector"):
|
||||||
|
hf_model = convert_xvector(base_model_name, hf_config, downstream_dict)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}")
|
||||||
|
|
||||||
hf_model.projector.weight.data = downstream_dict["projector.weight"]
|
if hf_config.use_weighted_layer_sum:
|
||||||
hf_model.projector.bias.data = downstream_dict["projector.bias"]
|
hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
|
||||||
hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
|
|
||||||
hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
|
|
||||||
|
|
||||||
hf_feature_extractor.save_pretrained(model_dump_path)
|
hf_feature_extractor.save_pretrained(model_dump_path)
|
||||||
hf_model.save_pretrained(model_dump_path)
|
hf_model.save_pretrained(model_dump_path)
|
||||||
|
|||||||
@@ -34,7 +34,13 @@ from ...file_utils import (
|
|||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput, SequenceClassifierOutput
|
from ...modeling_outputs import (
|
||||||
|
BaseModelOutput,
|
||||||
|
CausalLMOutput,
|
||||||
|
MaskedLMOutput,
|
||||||
|
SequenceClassifierOutput,
|
||||||
|
TokenClassifierOutput,
|
||||||
|
)
|
||||||
from ...modeling_utils import PreTrainedModel
|
from ...modeling_utils import PreTrainedModel
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from .configuration_wav2vec2 import Wav2Vec2Config
|
from .configuration_wav2vec2 import Wav2Vec2Config
|
||||||
@@ -45,9 +51,11 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||||
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
||||||
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
|
||||||
|
_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
||||||
|
|
||||||
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
|
_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks"
|
||||||
_SEQ_CLASS_PROCESSOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
|
_FRAME_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-sd"
|
||||||
|
_XVECTOR_CHECKPOINT = "superb/wav2vec2-base-superb-sv"
|
||||||
|
|
||||||
_HIDDEN_STATES_START_POSITION = 2
|
_HIDDEN_STATES_START_POSITION = 2
|
||||||
|
|
||||||
@@ -93,7 +101,7 @@ class Wav2Vec2BaseModelOutput(ModelOutput):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
Output type of :class:`~transformers.Wav2Vec2ForPreTrainingOutput`, with potential hidden states and attentions.
|
Output type of :class:`~transformers.Wav2Vec2ForPreTraining`, with potential hidden states and attentions.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
loss (`optional`, returned when :obj:`sample_negative_indices` are passed, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
||||||
@@ -132,6 +140,38 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
|
|||||||
diversity_loss: Optional[torch.FloatTensor] = None
|
diversity_loss: Optional[torch.FloatTensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class XVectorOutput(ModelOutput):
    """
    Output type of :class:`~transformers.Wav2Vec2ForXVector`.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            AMSoftmax classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Output of the classification layer, i.e. the hidden states that are fed to AMSoftmax.
        embeddings (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.xvector_output_dim)`):
            Per-utterance embeddings, suitable for vector similarity-based retrieval.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`: one entry
            for the embedding output followed by one entry per layer.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one per layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`, holding the post-softmax attention weights used to compute the
            weighted average in the self-attention heads.
    """

    # NOTE: field order is part of the interface (it defines the tuple layout of the output).
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
def _compute_mask_indices(
|
def _compute_mask_indices(
|
||||||
shape: Tuple[int, int],
|
shape: Tuple[int, int],
|
||||||
mask_prob: float,
|
mask_prob: float,
|
||||||
@@ -1707,7 +1747,7 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
processor_class=_SEQ_CLASS_PROCESSOR_FOR_DOC,
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
checkpoint=_SEQ_CLASS_CHECKPOINT,
|
||||||
output_type=SequenceClassifierOutput,
|
output_type=SequenceClassifierOutput,
|
||||||
config_class=_CONFIG_FOR_DOC,
|
config_class=_CONFIG_FOR_DOC,
|
||||||
@@ -1773,3 +1813,281 @@ class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
|
|||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
    """,
    WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable weight per hidden state, initialised to a uniform average.
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        # Applied to every frame independently: hidden_size -> num_labels.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_FRAME_CLASS_CHECKPOINT,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # NOTE: this head takes no ``labels`` argument and never computes a loss, so no
        # extra argument documentation is appended to the generated docstring and the
        # ``loss`` field of the returned output is always ``None``.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every hidden state, whatever the caller asked for.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer hidden states on a new dim 1 and mix them with
            # softmax-normalised learnable weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]

        return TokenClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class AMSoftmaxLoss(nn.Module):
    """
    Additive-margin softmax loss.

    Both the inputs and the class weight vectors are L2-normalised, so their product is a cosine similarity. The
    margin is subtracted from the similarity of the target class only, and the result is multiplied by ``scale``
    before being passed to cross-entropy.

    Args:
        input_dim: Size of the last dimension of the inputs passed to :meth:`forward`.
        num_labels: Number of target classes.
        scale: Factor applied to the (margin-adjusted) cosine similarities.
        margin: Value subtracted from the target-class cosine similarity.
    """

    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.num_labels = num_labels
        # One column per class; columns are normalised on every forward pass.
        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, hidden_states, labels):
        """
        Compute the loss.

        Args:
            hidden_states: 2-D tensor of shape ``(batch_size, input_dim)`` (``torch.mm`` requires 2-D inputs).
            labels: Integer class indices; flattened before use, so ``(batch_size,)`` and ``(batch_size, 1)`` are
                both accepted.

        Returns:
            The cross-entropy loss over the scaled, margin-adjusted cosine similarities.
        """
        labels = labels.flatten()
        weight = nn.functional.normalize(self.weight, dim=0)
        hidden_states = nn.functional.normalize(hidden_states, dim=1)
        cos_theta = torch.mm(hidden_states, weight)
        psi = cos_theta - self.margin

        # Apply the margin only at each sample's target class.
        onehot = nn.functional.one_hot(labels, self.num_labels)
        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
        return self.loss(logits, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class TDNNLayer(nn.Module):
    """
    Single time-delay layer: a dilated window over the time axis, a linear map over each flattened window and a ReLU.

    The layer's sizes are read from ``config.tdnn_dim``, ``config.tdnn_kernel`` and ``config.tdnn_dilation`` at
    index ``layer_id``. The input width is the previous layer's ``tdnn_dim`` entry, or the layer's own entry for
    ``layer_id == 0``.
    """

    def __init__(self, config, layer_id=0):
        super().__init__()
        if layer_id > 0:
            self.in_conv_dim = config.tdnn_dim[layer_id - 1]
        else:
            self.in_conv_dim = config.tdnn_dim[layer_id]
        self.out_conv_dim = config.tdnn_dim[layer_id]
        self.kernel_size = config.tdnn_kernel[layer_id]
        self.dilation = config.tdnn_dilation[layer_id]

        # Acts on a flattened window of `kernel_size` frames.
        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
        self.activation = nn.ReLU()

    def forward(self, hidden_states):
        # Add a channel dim so that `unfold` can slide over the (time, feature) plane.
        as_image = hidden_states.unsqueeze(1)
        # The window spans the full feature width, so it only moves along time.
        windows = nn.functional.unfold(
            as_image,
            (self.kernel_size, self.in_conv_dim),
            stride=(1, self.in_conv_dim),
            dilation=(self.dilation, 1),
        )
        # unfold puts the window index last; move it in front of the flattened window.
        frames = windows.transpose(1, 2)
        return self.activation(self.kernel(frames))
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings(
    """
    Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
    """,
    WAV_2_VEC_2_START_DOCSTRING,
)
class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable weight per hidden state, initialised to a uniform average.
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])

        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
        self.tdnn = nn.ModuleList(tdnn_layers)

        # Input is the concatenation of mean and std statistics, hence the factor 2.
        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)

        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)

        self.init_weights()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature extractor so that its parameters
        will not be updated during training.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the TDNN layers, taking each layer's kernel size and dilation into account.
        """

        def _conv_out_length(input_length, kernel_size, stride, dilation):
            # 1D convolutional layer output length formula (no padding) taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - dilation * (kernel_size - 1) - 1) // stride + 1

        # Each TDNN layer unfolds with its own dilation, so a dilated layer shortens the
        # sequence by `dilation * (kernel_size - 1)` frames rather than `kernel_size - 1`.
        for kernel_size, dilation in zip(self.config.tdnn_kernel, self.config.tdnn_dilation):
            input_lengths = _conv_out_length(input_lengths, kernel_size, 1, dilation)

        return input_lengths

    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_FEAT_EXTRACTOR_FOR_DOC,
        checkpoint=_XVECTOR_CHECKPOINT,
        output_type=XVectorOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="audio",
    )
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the AMSoftmax classification loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every hidden state, whatever the caller asked for.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Stack the per-layer hidden states on a new dim 1 and mix them with
            # softmax-normalised learnable weights.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)

        for tdnn_layer in self.tdnn:
            hidden_states = tdnn_layer(hidden_states)

        # Statistic Pooling
        if attention_mask is None:
            mean_features = hidden_states.mean(dim=1)
            std_features = hidden_states.std(dim=1)
        else:
            # Pool each sample over its own valid frames only.
            feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
            tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
            mean_features = []
            std_features = []
            for i, length in enumerate(tdnn_output_lengths):
                mean_features.append(hidden_states[i, :length].mean(dim=0))
                std_features.append(hidden_states[i, :length].std(dim=0))
            mean_features = torch.stack(mean_features)
            std_features = torch.stack(std_features)
        statistic_pooling = torch.cat([mean_features, std_features], dim=-1)

        output_embeddings = self.feature_extractor(statistic_pooling)
        logits = self.classifier(output_embeddings)

        loss = None
        if labels is not None:
            loss = self.objective(logits, labels)

        if not return_dict:
            output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return XVectorOutput(
            loss=loss,
            logits=logits,
            embeddings=output_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
||||||
|
|||||||
@@ -422,6 +422,30 @@ class AutoModelForAudioClassification:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForAudioFrameClassification:
    """Placeholder whose every entry point defers to ``requires_backends`` for the ``torch`` backend."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class AutoModelForAudioXVector:
    """Placeholder whose every entry point defers to ``requires_backends`` for the ``torch`` backend."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    def forward(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class AutoModelForCausalLM:
|
class AutoModelForCausalLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -4896,6 +4920,11 @@ class UniSpeechPreTrainedModel:
|
|||||||
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
UNISPEECH_SAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class UniSpeechSatForAudioFrameClassification:
    """Placeholder whose constructor defers to ``requires_backends`` for the ``torch`` backend."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class UniSpeechSatForCTC:
|
class UniSpeechSatForCTC:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -4918,6 +4947,11 @@ class UniSpeechSatForSequenceClassification:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class UniSpeechSatForXVector:
    """Placeholder whose constructor defers to ``requires_backends`` for the ``torch`` backend."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class UniSpeechSatModel:
|
class UniSpeechSatModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -5072,6 +5106,11 @@ class ViTPreTrainedModel:
|
|||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2ForAudioFrameClassification:
    """Placeholder whose constructor defers to ``requires_backends`` for the ``torch`` backend."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2ForCTC:
|
class Wav2Vec2ForCTC:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -5106,6 +5145,11 @@ class Wav2Vec2ForSequenceClassification:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Wav2Vec2ForXVector:
    """Placeholder whose constructor defers to ``requires_backends`` for the ``torch`` backend."""

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class Wav2Vec2Model:
|
class Wav2Vec2Model:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|||||||
@@ -33,9 +33,11 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
UniSpeechSatForAudioFrameClassification,
|
||||||
UniSpeechSatForCTC,
|
UniSpeechSatForCTC,
|
||||||
UniSpeechSatForPreTraining,
|
UniSpeechSatForPreTraining,
|
||||||
UniSpeechSatForSequenceClassification,
|
UniSpeechSatForSequenceClassification,
|
||||||
|
UniSpeechSatForXVector,
|
||||||
UniSpeechSatModel,
|
UniSpeechSatModel,
|
||||||
Wav2Vec2FeatureExtractor,
|
Wav2Vec2FeatureExtractor,
|
||||||
Wav2Vec2Processor,
|
Wav2Vec2Processor,
|
||||||
@@ -70,6 +72,10 @@ class UniSpeechSatModelTester:
|
|||||||
mask_time_length=2,
|
mask_time_length=2,
|
||||||
vocab_size=32,
|
vocab_size=32,
|
||||||
do_stable_layer_norm=False,
|
do_stable_layer_norm=False,
|
||||||
|
tdnn_dim=(32, 32),
|
||||||
|
tdnn_kernel=(3, 3),
|
||||||
|
tdnn_dilation=(1, 1),
|
||||||
|
xvector_output_dim=32,
|
||||||
scope=None,
|
scope=None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@@ -97,6 +103,10 @@ class UniSpeechSatModelTester:
|
|||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
self.do_stable_layer_norm = do_stable_layer_norm
|
||||||
self.mask_time_prob = mask_time_prob
|
self.mask_time_prob = mask_time_prob
|
||||||
self.mask_time_length = mask_time_length
|
self.mask_time_length = mask_time_length
|
||||||
|
self.tdnn_dim = tdnn_dim
|
||||||
|
self.tdnn_kernel = tdnn_kernel
|
||||||
|
self.tdnn_dilation = tdnn_dilation
|
||||||
|
self.xvector_output_dim = xvector_output_dim
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
output_seq_length = self.seq_length
|
output_seq_length = self.seq_length
|
||||||
@@ -135,6 +145,10 @@ class UniSpeechSatModelTester:
|
|||||||
hidden_act=self.hidden_act,
|
hidden_act=self.hidden_act,
|
||||||
initializer_range=self.initializer_range,
|
initializer_range=self.initializer_range,
|
||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
|
tdnn_dim=self.tdnn_dim,
|
||||||
|
tdnn_kernel=self.tdnn_kernel,
|
||||||
|
tdnn_dilation=self.tdnn_dilation,
|
||||||
|
xvector_output_dim=self.xvector_output_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, input_values, attention_mask):
|
def create_and_check_model(self, config, input_values, attention_mask):
|
||||||
@@ -277,6 +291,30 @@ class UniSpeechSatModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_xvector_training(self, config, *args):
    """Run one forward/backward pass of the x-vector model on zero-padded inputs and check the loss is finite."""
    config.ctc_zero_infinity = True
    model = UniSpeechSatForXVector(config=config)
    model.to(torch_device)
    model.train()

    # only the head stays trainable
    model.freeze_base_model()

    # the TDNN layers shorten the sequence, so start from twice the default length
    input_values = floats_tensor([self.batch_size, self.seq_length * 2], self.vocab_size)

    full_length = input_values.shape[-1]
    input_lengths = [full_length // divisor for divisor in [4, 2, 1]]
    labels = ids_tensor((input_values.shape[0], 1), len(model.config.id2label))

    # zero out everything past each sample's nominal length
    for row, valid_length in enumerate(input_lengths):
        input_values[row, valid_length:] = 0.0

    loss = model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||||
|
|
||||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
model = UniSpeechSatForCTC(config)
|
model = UniSpeechSatForCTC(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -300,7 +338,14 @@ class UniSpeechSatModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(UniSpeechSatForCTC, UniSpeechSatForPreTraining, UniSpeechSatModel, UniSpeechSatForSequenceClassification)
|
(
|
||||||
|
UniSpeechSatForCTC,
|
||||||
|
UniSpeechSatForPreTraining,
|
||||||
|
UniSpeechSatModel,
|
||||||
|
UniSpeechSatForSequenceClassification,
|
||||||
|
UniSpeechSatForAudioFrameClassification,
|
||||||
|
UniSpeechSatForXVector,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -335,6 +380,10 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xvector_train(self):
    # Delegates to the model tester's x-vector training check.
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_xvector_training(*prepared)
|
||||||
|
|
||||||
def test_labels_out_of_vocab(self):
|
def test_labels_out_of_vocab(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
@@ -417,6 +466,7 @@ class UniSpeechSatModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"feature_projection.projection.weight",
|
"feature_projection.projection.weight",
|
||||||
"feature_projection.projection.bias",
|
"feature_projection.projection.bias",
|
||||||
"label_embeddings_concat",
|
"label_embeddings_concat",
|
||||||
|
"objective.weight",
|
||||||
]
|
]
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if any([x in name for x in uniform_init_parms]):
|
if any([x in name for x in uniform_init_parms]):
|
||||||
@@ -623,6 +673,7 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"feature_projection.projection.weight",
|
"feature_projection.projection.weight",
|
||||||
"feature_projection.projection.bias",
|
"feature_projection.projection.bias",
|
||||||
"label_embeddings_concat",
|
"label_embeddings_concat",
|
||||||
|
"objective.weight",
|
||||||
]
|
]
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if any([x in name for x in uniform_init_parms]):
|
if any([x in name for x in uniform_init_parms]):
|
||||||
@@ -811,3 +862,56 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
|
self.assertTrue(torch.allclose(outputs.last_hidden_state[:, :2, -2:], expected_hidden_states_slice, atol=1e-3))
|
||||||
|
|
||||||
|
def test_inference_diarization(self):
    """Compare frame-classification logits of the pretrained diarization checkpoint with reference values."""
    model = UniSpeechSatForAudioFrameClassification.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
    model = model.to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sd")
    input_data = self._load_superb("sd", 4)
    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

    input_values = inputs.input_values.to(torch_device)
    attention_mask = inputs.attention_mask.to(torch_device)
    with torch.no_grad():
        outputs = model(input_values, attention_mask=attention_mask)
    # threshold the logits at 0 to get a 0/1 decision per sample, frame and speaker
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.6119, -5.5845], [-3.7772, -5.4824], [-3.6914, -5.1619], [-4.7560, -5.0496]],
            [[-6.3785, -4.8365], [-5.5863, -5.4149], [-5.5639, -4.8469], [-6.1511, -4.0052]],
            [[-6.0355, -3.7414], [-5.5968, -4.8061], [-5.4620, -4.7310], [-5.5864, -4.6078]],
            [[-5.9493, -4.8963], [-4.4050, -5.4476], [-4.1755, -5.1395], [-4.0272, -4.3705]],
        ],
        device=torch_device,
    )
    # number of active frames per speaker in the first sample
    self.assertEqual(labels[0, :, 0].sum(), 270)
    self.assertEqual(labels[0, :, 1].sum(), 647)
    first_frames = outputs.logits[:, :4]
    self.assertTrue(torch.allclose(first_frames, expected_logits, atol=1e-3))
|
||||||
|
|
||||||
|
def test_inference_speaker_verification(self):
    """Check pairwise embedding similarities and the loss of the pretrained speaker-verification checkpoint."""
    model = UniSpeechSatForXVector.from_pretrained("anton-l/unispeech-sat-base-plus-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/unispeech-sat-base-plus-sv")
    input_data = self._load_superb("si", 4)

    inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
    # One class index per utterance. The tensor is 1-D, so the former `.T` was a no-op
    # (and `.T` on non-2-D tensors is deprecated in PyTorch); it has been dropped.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1)

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).item(), 0.9671, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).item(), 0.4941, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).item(), 0.5616, 3)

    self.assertAlmostEqual(outputs.loss.item(), 18.5925, 3)
|
||||||
|
|||||||
@@ -44,10 +44,12 @@ if is_torch_available():
|
|||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
Wav2Vec2FeatureExtractor,
|
Wav2Vec2FeatureExtractor,
|
||||||
|
Wav2Vec2ForAudioFrameClassification,
|
||||||
Wav2Vec2ForCTC,
|
Wav2Vec2ForCTC,
|
||||||
Wav2Vec2ForMaskedLM,
|
Wav2Vec2ForMaskedLM,
|
||||||
Wav2Vec2ForPreTraining,
|
Wav2Vec2ForPreTraining,
|
||||||
Wav2Vec2ForSequenceClassification,
|
Wav2Vec2ForSequenceClassification,
|
||||||
|
Wav2Vec2ForXVector,
|
||||||
Wav2Vec2Model,
|
Wav2Vec2Model,
|
||||||
Wav2Vec2Processor,
|
Wav2Vec2Processor,
|
||||||
)
|
)
|
||||||
@@ -96,6 +98,10 @@ class Wav2Vec2ModelTester:
|
|||||||
do_stable_layer_norm=False,
|
do_stable_layer_norm=False,
|
||||||
num_adapter_layers=1,
|
num_adapter_layers=1,
|
||||||
adapter_stride=2,
|
adapter_stride=2,
|
||||||
|
tdnn_dim=(32, 32),
|
||||||
|
tdnn_kernel=(5, 3),
|
||||||
|
tdnn_dilation=(1, 2),
|
||||||
|
xvector_output_dim=32,
|
||||||
scope=None,
|
scope=None,
|
||||||
):
|
):
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
@@ -126,6 +132,10 @@ class Wav2Vec2ModelTester:
|
|||||||
self.mask_time_prob = mask_time_prob
|
self.mask_time_prob = mask_time_prob
|
||||||
self.mask_time_length = mask_time_length
|
self.mask_time_length = mask_time_length
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
self.tdnn_dim = tdnn_dim
|
||||||
|
self.tdnn_kernel = tdnn_kernel
|
||||||
|
self.tdnn_dilation = tdnn_dilation
|
||||||
|
self.xvector_output_dim = xvector_output_dim
|
||||||
|
|
||||||
output_seq_length = self.seq_length
|
output_seq_length = self.seq_length
|
||||||
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
for kernel, stride in zip(self.conv_kernel, self.conv_stride):
|
||||||
@@ -168,6 +178,10 @@ class Wav2Vec2ModelTester:
|
|||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
num_adapter_layers=self.num_adapter_layers,
|
num_adapter_layers=self.num_adapter_layers,
|
||||||
adapter_stride=self.adapter_stride,
|
adapter_stride=self.adapter_stride,
|
||||||
|
tdnn_dim=self.tdnn_dim,
|
||||||
|
tdnn_kernel=self.tdnn_kernel,
|
||||||
|
tdnn_dilation=self.tdnn_dilation,
|
||||||
|
xvector_output_dim=self.xvector_output_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_model(self, config, input_values, attention_mask):
|
def create_and_check_model(self, config, input_values, attention_mask):
|
||||||
@@ -332,6 +346,29 @@ class Wav2Vec2ModelTester:
|
|||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
def check_xvector_training(self, config, input_values, *args):
    """Run one training-mode forward/backward pass through ``Wav2Vec2ForXVector``.

    The base model is frozen, the first three inputs are zero-padded to
    different lengths, and the resulting loss is asserted not to be infinite
    before back-propagating.
    """
    # NOTE(review): flag kept exactly as in the original; whether the x-vector
    # head reads it is not visible from this file - confirm before removing.
    config.ctc_zero_infinity = True
    xvector_model = Wav2Vec2ForXVector(config=config)
    xvector_model.to(torch_device)
    xvector_model.train()

    # freeze everything but the classification head
    xvector_model.freeze_base_model()

    input_values = input_values[:3]
    full_length = input_values.shape[-1]

    labels = ids_tensor((input_values.shape[0], 1), len(xvector_model.config.id2label))

    # pad input: rows 0, 1, 2 keep a quarter, a half and all of their samples
    for row, divisor in enumerate((4, 2, 1)):
        input_values[row, full_length // divisor :] = 0.0

    loss = xvector_model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||||
|
|
||||||
def check_labels_out_of_vocab(self, config, input_values, *args):
|
def check_labels_out_of_vocab(self, config, input_values, *args):
|
||||||
model = Wav2Vec2ForCTC(config)
|
model = Wav2Vec2ForCTC(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
@@ -398,6 +435,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xvector_train(self):
    """Run the model tester's x-vector training check on freshly prepared config and inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_xvector_training(*prepared)
|
||||||
|
|
||||||
def test_labels_out_of_vocab(self):
|
def test_labels_out_of_vocab(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
@@ -489,6 +530,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"project_q.bias",
|
"project_q.bias",
|
||||||
"feature_projection.projection.weight",
|
"feature_projection.projection.weight",
|
||||||
"feature_projection.projection.bias",
|
"feature_projection.projection.bias",
|
||||||
|
"objective.weight",
|
||||||
]
|
]
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if any([x in name for x in uniform_init_parms]):
|
if any([x in name for x in uniform_init_parms]):
|
||||||
@@ -573,7 +615,15 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (
|
all_model_classes = (
|
||||||
(Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForSequenceClassification, Wav2Vec2ForPreTraining)
|
(
|
||||||
|
Wav2Vec2ForCTC,
|
||||||
|
Wav2Vec2Model,
|
||||||
|
Wav2Vec2ForMaskedLM,
|
||||||
|
Wav2Vec2ForSequenceClassification,
|
||||||
|
Wav2Vec2ForPreTraining,
|
||||||
|
Wav2Vec2ForAudioFrameClassification,
|
||||||
|
Wav2Vec2ForXVector,
|
||||||
|
)
|
||||||
if is_torch_available()
|
if is_torch_available()
|
||||||
else ()
|
else ()
|
||||||
)
|
)
|
||||||
@@ -622,6 +672,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
self.model_tester.check_seq_classifier_training(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xvector_train(self):
    """Run the model tester's x-vector training check on freshly prepared config and inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_xvector_training(*prepared)
|
||||||
|
|
||||||
def test_labels_out_of_vocab(self):
|
def test_labels_out_of_vocab(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
self.model_tester.check_labels_out_of_vocab(*config_and_inputs)
|
||||||
@@ -703,6 +757,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"project_q.bias",
|
"project_q.bias",
|
||||||
"feature_projection.projection.weight",
|
"feature_projection.projection.weight",
|
||||||
"feature_projection.projection.bias",
|
"feature_projection.projection.bias",
|
||||||
|
"objective.weight",
|
||||||
]
|
]
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
if any([x in name for x in uniform_init_parms]):
|
if any([x in name for x in uniform_init_parms]):
|
||||||
@@ -1369,3 +1424,54 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
transcription = processor.batch_decode(logits.cpu().numpy()).text
|
transcription = processor.batch_decode(logits.cpu().numpy()).text
|
||||||
|
|
||||||
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")
|
||||||
|
|
||||||
|
def test_inference_diarization(self):
    """Check frame-classification output of the ``wav2vec2-base-superb-sd`` checkpoint.

    Asserts the number of positive frames per speaker for the first example
    and compares the first four frames of every example with reference
    logits (within ``atol=1e-3``).
    """
    checkpoint = "anton-l/wav2vec2-base-superb-sd"
    diarization_model = Wav2Vec2ForAudioFrameClassification.from_pretrained(checkpoint).to(torch_device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)
    samples = self._load_superb("sd", 4)
    encoded = feature_extractor(samples["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)

    input_values = encoded.input_values.to(torch_device)
    attention_mask = encoded.attention_mask.to(torch_device)
    with torch.no_grad():
        outputs = diarization_model(input_values, attention_mask=attention_mask)
    # binary activity mask obtained by thresholding the logits at zero;
    # indexed below as [example, frame, speaker]
    labels = (outputs.logits > 0).long()

    # s3prl logits for the same batch
    expected_logits = torch.tensor(
        [
            [[-5.2807, -5.1272], [-5.4059, -4.7757], [-5.2764, -4.9621], [-5.0117, -4.5851]],
            [[-1.7643, -0.5462], [-1.7369, -0.2649], [-1.5066, -0.6200], [-4.5703, -2.4863]],
            [[-0.8656, -0.4783], [-0.8899, -0.3289], [-0.9267, -0.5781], [-0.7817, -0.4619]],
            [[-4.8625, -2.5316], [-5.2339, -2.2155], [-4.9835, -2.0344], [-4.4727, -1.8421]],
        ],
        device=torch_device,
    )
    self.assertEqual(labels[0, :, 0].sum(), 555)
    self.assertEqual(labels[0, :, 1].sum(), 299)
    self.assertTrue(torch.allclose(outputs.logits[:, :4], expected_logits, atol=1e-3))
|
||||||
|
|
||||||
|
def test_inference_speaker_verification(self):
    """Check x-vector embeddings and loss of the ``wav2vec2-base-superb-sv`` checkpoint.

    Four utterances are embedded, the embeddings are L2-normalised, and the
    cosine similarity of selected pairs plus the returned loss are compared
    with reference values to three decimal places.
    """
    model = Wav2Vec2ForXVector.from_pretrained("anton-l/wav2vec2-base-superb-sv").to(torch_device)
    processor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/wav2vec2-base-superb-sv")
    input_data = self._load_superb("si", 4)

    inputs = processor(input_data["speech"], return_tensors="pt", padding=True, sampling_rate=16_000)
    # One label per utterance. The former trailing `.T` was dropped: on a 1-D
    # tensor it is a no-op and is deprecated for tensors that are not 2-D.
    labels = torch.tensor([5, 1, 1, 3], device=torch_device)

    with torch.no_grad():
        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
        outputs = model(input_values, attention_mask=attention_mask, labels=labels)
    embeddings = torch.nn.functional.normalize(outputs.embeddings, dim=-1).cpu()

    cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    # id10002 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[1], embeddings[2]).numpy(), 0.9758, 3)
    # id10006 vs id10002
    self.assertAlmostEqual(cosine_sim(embeddings[0], embeddings[1]).numpy(), 0.7579, 3)
    # id10002 vs id10004
    self.assertAlmostEqual(cosine_sim(embeddings[2], embeddings[3]).numpy(), 0.7594, 3)

    self.assertAlmostEqual(outputs.loss.item(), 17.7963, 3)
|
||||||
|
|||||||
@@ -74,6 +74,12 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
|||||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
|
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
|
||||||
"AutoModelForNextSentencePrediction",
|
"AutoModelForNextSentencePrediction",
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"audio-frame-classification",
|
||||||
|
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
|
||||||
|
"AutoModelForAudioFrameClassification",
|
||||||
|
),
|
||||||
|
("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user