Add Fine-Tuning for Wav2Vec2 (#10145)
* add encode labels function to tokenizer * start adding finetuning * init dropout * upload * correct convert script * apply changes * fix second typo * make first dummy training run * adapt convert script * push confg for comparison * remove conf * finish training * adapt data collator * add research folder * update according to fairseq feedback * some minor corrections * refactor masking indices a bit * some minor changes * clean tokenizer * finish clean-up * remove previous logic * update run script * correct training * finish changes * finish model * correct bug * fix training a bit more * add some tests * finish gradient checkpointing * finish example * correct gradient checkpointing * improve tokenization method * revert changes in tokenizer * revert general change * adapt fine-tuning * update * save intermediate test * Update README.md * finish finetuning * delete conversion script * Update src/transformers/models/wav2vec2/configuration_wav2vec2.py * Update src/transformers/models/wav2vec2/processing_wav2vec2.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * finish wav2vec2 script * finish wav2vec2 fine-tuning * finalize test * correct test * adapt tests * finish * remove test file Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
3c733f3208
commit
0234de8418
@@ -62,7 +62,7 @@ Wav2Vec2Processor
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: transformers.Wav2Vec2Processor
|
||||
:members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
|
||||
:members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
|
||||
|
||||
|
||||
Wav2Vec2Model
|
||||
|
||||
8
examples/research_projects/wav2vec2/README.md
Normal file
8
examples/research_projects/wav2vec2/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
## Fine-tuning Wav2Vec2
|
||||
|
||||
The `run_asr.py` script allows one to fine-tune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).
|
||||
|
||||
This finetuning script can also be run as a google colab [TODO: here]( ).
|
||||
|
||||
The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten).
|
||||
Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose), tagging `@patrickvonplaten`.
|
||||
21
examples/research_projects/wav2vec2/finetune_base_100.sh
Executable file
21
examples/research_projects/wav2vec2/finetune_base_100.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
# Launch run_asr.py to fine-tune facebook/wav2vec2-base on the "train.100"
# split of the "clean" librispeech_asr configuration.
# The arguments are collected in an array and passed on unchanged, in order.
train_args=(
  --output_dir="./wav2vec2-base-100h"
  --num_train_epochs="30"
  --per_device_train_batch_size="32"
  --per_device_eval_batch_size="32"
  --evaluation_strategy="steps"
  --save_total_limit="3"
  --save_steps="500"
  --eval_steps="100"
  --logging_steps="50"
  --learning_rate="5e-4"
  --warmup_steps="3000"
  --model_name_or_path="facebook/wav2vec2-base"
  --fp16
  --dataset_name="librispeech_asr"
  --dataset_config_name="clean"
  --train_split_name="train.100"
  --preprocessing_num_workers="32"
  --group_by_length
  --freeze_feature_extractor
)

python run_asr.py "${train_args[@]}"
|
||||
21
examples/research_projects/wav2vec2/finetune_large_lv60_100.sh
Executable file
21
examples/research_projects/wav2vec2/finetune_large_lv60_100.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
# Launch run_asr.py to fine-tune facebook/wav2vec2-large-lv60 on the "train.100"
# split of the "clean" librispeech_asr configuration.
# The arguments are collected in an array and passed on unchanged, in order.
train_args=(
  --output_dir="./wav2vec2-large-lv60-100h"
  --num_train_epochs="30"
  --per_device_train_batch_size="16"
  --per_device_eval_batch_size="16"
  --evaluation_strategy="steps"
  --save_total_limit="3"
  --save_steps="500"
  --eval_steps="100"
  --logging_steps="50"
  --learning_rate="5e-4"
  --warmup_steps="3000"
  --model_name_or_path="facebook/wav2vec2-large-lv60"
  --fp16
  --dataset_name="librispeech_asr"
  --dataset_config_name="clean"
  --train_split_name="train.100"
  --preprocessing_num_workers="32"
  --group_by_length
  --freeze_feature_extractor
)

python run_asr.py "${train_args[@]}"
|
||||
4
examples/research_projects/wav2vec2/requirements.txt
Normal file
4
examples/research_projects/wav2vec2/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
transformers
datasets
torch >= 1.5.0
jiwer
# run_asr.py imports soundfile (as `sf`) to read the audio files
soundfile
|
||||
281
examples/research_projects/wav2vec2/run_asr.py
Executable file
281
examples/research_projects/wav2vec2/run_asr.py
Executable file
@@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
|
||||
import soundfile as sf
|
||||
from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Processor,
|
||||
is_apex_available,
|
||||
)
|
||||
|
||||
|
||||
if is_apex_available():
|
||||
from apex import amp
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("1.6"):
|
||||
_is_native_amp_available = True
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
@dataclass
class ModelArguments:
    """
    Command-line options describing the pretrained checkpoint that is fine-tuned.
    """

    # Required: there is no default, so it must always be supplied.
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models",
        },
    )
    # Forwarded as `cache_dir` when the model and processor are loaded.
    cache_dir: Optional[str] = field(
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co",
        },
        default=None,
    )
    # Defaults to freezing the feature extractor.
    freeze_feature_extractor: Optional[bool] = field(
        metadata={
            "help": "Whether to freeze the feature extractor layers of the model.",
        },
        default=True,
    )
|
||||
|
||||
|
||||
@dataclass
class DataTrainingArguments:
    """
    Command-line options describing the data used for training and evaluation.

    The fields of this class are turned into argparse arguments by
    `HfArgumentParser`, so each of them can be set on the command line.
    """

    dataset_name: str = field(
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
        default=None,
    )
    dataset_config_name: Optional[str] = field(
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
        default=None,
    )
    train_split_name: Optional[str] = field(
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
        default="train",
    )
    overwrite_cache: bool = field(
        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
        default=False,
    )
    preprocessing_num_workers: Optional[int] = field(
        metadata={"help": "The number of processes to use for the preprocessing."},
        default=None,
    )
|
||||
|
||||
|
||||
@dataclass
class DataCollatorCTCWithPadding:
    """
    Collator that pads the audio inputs and the label sequences of a batch separately.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor whose ``pad`` method is used for both the inputs and the labels.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Padding strategy handed to ``processor.pad`` for inputs and labels alike. One of:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
              single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, forwarded to ``processor.pad`` so that the ``input_values`` are padded to a multiple of the
            provided value.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            Same as :obj:`pad_to_multiple_of`, but applied when padding the ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Inputs and labels have different lengths and need different padding
        # methods, so they are collected into two separate lists.
        audio_items = []
        label_items = []
        for feature in features:
            audio_items.append({"input_values": feature["input_values"]})
            label_items.append({"input_ids": feature["labels"]})

        padded_inputs = self.processor.pad(
            audio_items,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Inside this context the processor pads the label sequences.
        with self.processor.as_target_processor():
            padded_labels = self.processor.pad(
                label_items,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Every padded label position becomes -100 so that the loss ignores it.
        is_padding = padded_labels.attention_mask.ne(1)
        padded_inputs["labels"] = padded_labels["input_ids"].masked_fill(is_padding, -100)

        return padded_inputs
|
||||
|
||||
|
||||
class CTCTrainer(Trainer):
    """
    :class:`Trainer` subclass whose ``training_step`` reduces the loss itself when running on more than one GPU,
    following the model's ``ctc_loss_reduction`` setting.
    """

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.

        Raises:
            ValueError: If more than one GPU is used and ``ctc_loss_reduction`` is neither ``"mean"`` nor ``"sum"``.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # In this branch the config is reached through `model.module`, i.e. the model is expected to be
            # wrapped. Read the setting once so the error message below uses the same (valid) access path.
            ctc_loss_reduction = model.module.config.ctc_loss_reduction
            if ctc_loss_reduction == "mean":
                loss = loss.mean()
            elif ctc_loss_reduction == "sum":
                # normalize the summed loss by the number of label positions that are not masked (-100)
                loss = loss.sum() / (inputs["labels"] >= 0).sum()
            else:
                raise ValueError(f"{ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        # Run the backward pass with whichever mixed-precision / distributed backend is active.
        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
|
||||
|
||||
|
||||
def main():
    """
    Fine-tune a pretrained Wav2Vec2 CTC model on a speech dataset.

    Parses the model, data and training arguments from the command line, loads the model, the processor and the
    train/validation splits, converts the audio files and transcriptions into model inputs and labels, and runs
    training with :class:`CTCTrainer`, reporting the word error rate on the validation split.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = Wav2Vec2ForCTC.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)

    train_dataset = datasets.load_dataset(
        data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
    )
    val_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_config_name, split="validation")

    wer_metric = datasets.load_metric("wer")

    def map_to_array(batch):
        # read the audio file of a single example and store the samples with their sampling rate
        speech_array, sampling_rate = sf.read(batch["file"])
        batch["speech"] = speech_array
        batch["sampling_rate"] = sampling_rate
        return batch

    train_dataset = train_dataset.map(map_to_array, remove_columns=["file"])
    val_dataset = val_dataset.map(map_to_array, remove_columns=["file"])

    def prepare_dataset(batch):
        # check that all files have the correct sampling rate
        assert (
            len(set(batch["sampling_rate"])) == 1
        ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."

        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
        # inside this context the processor encodes the transcriptions into label ids
        with processor.as_target_processor():
            batch["labels"] = processor(batch["text"]).input_ids
        return batch

    train_dataset = train_dataset.map(
        prepare_dataset,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
    )
    val_dataset = val_dataset.map(
        prepare_dataset,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
    )

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    def compute_metrics(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # -100 marks the padded label positions (set by the data collator). Map them back to the tokenizer's
        # padding id before decoding instead of assuming that this id is 0.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    if model_args.freeze_feature_extractor:
        model.freeze_feature_extractor()

    trainer = CTCTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -92,6 +92,33 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
Whether to apply `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
|
||||
True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
|
||||
False`` corresponds to applying layer norm after the attention layer.
|
||||
freeze_feat_extract_train (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to freeze the weights of the feature extractor when training.
|
||||
apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
|
||||
`SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
|
||||
<https://arxiv.org/abs/1904.08779>`__.
|
||||
mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
|
||||
Probability of each feature vector along the time axis to be chosen as the start of the vector span to be
|
||||
masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be
|
||||
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
|
||||
mask_time_length (:obj:`int`, `optional`, defaults to 10):
|
||||
Length of vector span along the time axis.
|
||||
mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Probability of each feature vector along the feature axis to be chosen as the start of the vector span to
|
||||
be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` feature vectors will be
|
||||
masked along the feature axis. This is only relevant if ``apply_spec_augment is True``.
|
||||
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
|
||||
Length of vector span along the feature axis.
|
||||
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
|
||||
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
|
||||
instance of :class:`~transformers.Wav2Vec2ForCTC`.
|
||||
ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||
instance of :class:`~transformers.Wav2Vec2ForCTC`.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -116,12 +143,15 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
|
||||
attention_probs_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
|
||||
hidden_dropout=0.1,
|
||||
activation_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
feat_proj_dropout=0.1,
|
||||
final_dropout=0.1,
|
||||
layerdrop=0.1,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
feat_extract_norm="group",
|
||||
feat_extract_dropout=0.0,
|
||||
feat_extract_activation="gelu",
|
||||
conv_dim=(512, 512, 512, 512, 512, 512, 512),
|
||||
conv_stride=(5, 2, 2, 2, 2, 2, 2),
|
||||
@@ -130,6 +160,15 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
num_conv_pos_embeddings=128,
|
||||
num_conv_pos_embedding_groups=16,
|
||||
do_stable_layer_norm=False,
|
||||
freeze_feat_extract_train=True,
|
||||
apply_spec_augment=True,
|
||||
mask_time_prob=0.05,
|
||||
mask_time_length=10,
|
||||
mask_feature_prob=0.0,
|
||||
mask_feature_length=10,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
@@ -138,7 +177,6 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
||||
self.hidden_size = hidden_size
|
||||
self.feat_extract_norm = feat_extract_norm
|
||||
self.feat_extract_dropout = feat_extract_dropout
|
||||
self.feat_extract_activation = feat_extract_activation
|
||||
self.conv_dim = list(conv_dim)
|
||||
self.conv_stride = list(conv_stride)
|
||||
@@ -151,12 +189,18 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.feat_proj_dropout = feat_proj_dropout
|
||||
self.final_dropout = final_dropout
|
||||
self.layerdrop = layerdrop
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.freeze_feat_extract_train = freeze_feat_extract_train
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
if (
|
||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||
@@ -169,3 +213,14 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)"
|
||||
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
||||
)
|
||||
|
||||
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
||||
self.apply_spec_augment = apply_spec_augment
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.mask_feature_prob = mask_feature_prob
|
||||
self.mask_feature_length = mask_feature_length
|
||||
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
@@ -20,26 +20,27 @@ import argparse
|
||||
import fairseq
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAPPING = {
|
||||
"post_extract_proj": "wav2vec2.feature_projection.projection",
|
||||
"encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv",
|
||||
"self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj",
|
||||
"self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj",
|
||||
"self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm",
|
||||
"fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense",
|
||||
"fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense",
|
||||
"final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "wav2vec2.encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm",
|
||||
"post_extract_proj": "feature_projection.projection",
|
||||
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
||||
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
||||
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
||||
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
||||
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
||||
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
||||
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
||||
"w2v_encoder.proj": "lm_head",
|
||||
"mask_emb": "masked_spec_embed",
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +48,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
for attribute in key.split("."):
|
||||
hf_pointer = getattr(hf_pointer, attribute)
|
||||
|
||||
if weight_type is not None:
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
assert (
|
||||
hf_shape == value.shape
|
||||
), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}"
|
||||
@@ -59,26 +64,32 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.")
|
||||
else:
|
||||
hf_pointer.data = value
|
||||
|
||||
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def recursively_load_weights(fairseq_model, hf_model):
|
||||
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
unused_weights = []
|
||||
fairseq_dict = fairseq_model.state_dict()
|
||||
|
||||
feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor
|
||||
|
||||
for name, value in fairseq_dict.items():
|
||||
is_used = False
|
||||
if "conv_layers" in name:
|
||||
load_conv_layer(
|
||||
name,
|
||||
value,
|
||||
hf_model.wav2vec2.feature_extractor,
|
||||
feature_extractor,
|
||||
unused_weights,
|
||||
hf_model.config.feat_extract_norm == "group",
|
||||
)
|
||||
is_used = True
|
||||
else:
|
||||
for key, mapped_key in MAPPING.items():
|
||||
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
||||
if key in name:
|
||||
is_used = True
|
||||
if "*" in mapped_key:
|
||||
@@ -92,6 +103,8 @@ def recursively_load_weights(fairseq_model, hf_model):
|
||||
weight_type = "weight"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
else:
|
||||
weight_type = None
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
continue
|
||||
if not is_used:
|
||||
@@ -137,18 +150,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None):
|
||||
def convert_wav2vec2_checkpoint(
|
||||
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())
|
||||
if config_path is not None:
|
||||
config = Wav2Vec2Config.from_pretrained(config_path)
|
||||
else:
|
||||
config = Wav2Vec2Config()
|
||||
|
||||
if is_finetuned:
|
||||
hf_wav2vec = Wav2Vec2ForCTC(config)
|
||||
else:
|
||||
hf_wav2vec = Wav2Vec2Model(config)
|
||||
|
||||
if is_finetuned:
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[checkpoint_path], arg_overrides={"data": dict_path}
|
||||
)
|
||||
else:
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
|
||||
|
||||
model = model[0].eval()
|
||||
|
||||
recursively_load_weights(model, hf_wav2vec)
|
||||
recursively_load_weights(model, hf_wav2vec, is_finetuned)
|
||||
|
||||
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
@@ -158,5 +185,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument(
|
||||
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path)
|
||||
convert_wav2vec2_checkpoint(
|
||||
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
|
||||
)
|
||||
|
||||
@@ -14,10 +14,10 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Wav2Vec2 model. """
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
@@ -44,6 +44,77 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
min_masks: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_length: size of the mask
|
||||
min_masks: minimum number of masked spans
|
||||
|
||||
Adapted from `fairseq's data_utils.py
|
||||
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
|
||||
"""
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
@@ -57,12 +128,10 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||
stride=config.conv_stride[layer_id],
|
||||
bias=config.conv_bias,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
def forward(self, hidden_states):
    """
    Apply the convolution and then the activation function to :obj:`hidden_states`.

    The former dropout step is gone: the ``self.dropout`` module it relied on is no longer created in
    ``__init__``, so calling it here would raise an ``AttributeError``.
    """
    hidden_states = self.conv(hidden_states)
    hidden_states = self.activation(hidden_states)
    return hidden_states
|
||||
|
||||
@@ -80,13 +149,11 @@ class Wav2Vec2LayerNormConvLayer(nn.Module):
|
||||
stride=config.conv_stride[layer_id],
|
||||
bias=config.conv_bias,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(-2, -1)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@@ -109,14 +176,12 @@ class Wav2Vec2GroupNormConvLayer(nn.Module):
|
||||
stride=config.conv_stride[layer_id],
|
||||
bias=config.conv_bias,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
||||
|
||||
def forward(self, hidden_states):
    """
    Apply the convolution, the group normalization and the activation function, in that order.

    The former dropout step is gone: the ``self.dropout`` module it relied on is no longer created in
    ``__init__``, so calling it here would raise an ``AttributeError``.
    """
    hidden_states = self.conv(hidden_states)
    hidden_states = self.layer_norm(hidden_states)
    hidden_states = self.activation(hidden_states)
    return hidden_states
|
||||
@@ -178,6 +243,10 @@ class Wav2Vec2FeatureExtractor(nn.Module):
|
||||
)
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
|
||||
def _freeze_parameters(self):
    """Switch off gradient computation for every parameter returned by ``self.parameters()``."""
    for weight in self.parameters():
        weight.requires_grad = False
|
||||
|
||||
def forward(self, input_values):
|
||||
hidden_states = input_values[:, None]
|
||||
for conv_layer in self.conv_layers:
|
||||
@@ -191,7 +260,7 @@ class Wav2Vec2FeatureProjection(nn.Module):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
||||
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@@ -346,7 +415,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
class Wav2Vec2FeedForward(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.intermediate_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
||||
|
||||
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
@@ -355,7 +424,7 @@ class Wav2Vec2FeedForward(nn.Module):
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.intermediate_dense(hidden_states)
|
||||
@@ -381,10 +450,10 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
||||
self.attention = Wav2Vec2Attention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.hidden_dropout_prob,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=False,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -401,7 +470,12 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
||||
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
return hidden_states, attn_weights
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
@@ -410,10 +484,10 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
self.attention = Wav2Vec2Attention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.hidden_dropout_prob,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=False,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -428,7 +502,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
hidden_states = attn_residual + hidden_states
|
||||
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||
|
||||
return hidden_states, attn_weights
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Wav2Vec2Encoder(nn.Module):
|
||||
@@ -437,8 +516,7 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
self.config = config
|
||||
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# IMPORTANT: the param for dropout is probs wrong
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
@@ -471,12 +549,32 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = np.random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
# create gradient checkpointing function
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
@@ -496,8 +594,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
self.config = config
|
||||
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# IMPORTANT: the param for dropout is probs wrong
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layers = nn.ModuleList(
|
||||
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
@@ -531,12 +628,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = np.random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
# create gradient checkpointing function
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
@@ -584,7 +701,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
def _conv_out_length(input_length, kernel_size, stride):
    # 1D convolutional layer output length formula taken
    # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
    # (with padding 0 and dilation 1 it reduces to the expression below).
    # Integer floor division is used so that integer lengths stay integers;
    # the earlier `torch.floor(...)` variant only accepted tensors and left the
    # line below unreachable.
    return (input_length - kernel_size) // stride + 1
|
||||
|
||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
@@ -659,6 +776,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
|
||||
self.feature_projection = Wav2Vec2FeatureProjection(config)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
if config.do_stable_layer_norm:
|
||||
self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
|
||||
else:
|
||||
@@ -726,6 +845,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states)
|
||||
|
||||
if self.config.apply_spec_augment and self.training:
|
||||
batch_size, sequence_length, hidden_size = hidden_states.size()
|
||||
|
||||
# apply SpecAugment along time axis
|
||||
if self.config.mask_time_prob > 0:
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
(batch_size, sequence_length),
|
||||
self.config.mask_time_prob,
|
||||
self.config.mask_time_length,
|
||||
attention_mask=attention_mask,
|
||||
min_masks=2,
|
||||
)
|
||||
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
# apply SpecAugment along feature axis
|
||||
if self.config.mask_feature_prob > 0:
|
||||
mask_feature_indices = _compute_mask_indices(
|
||||
(batch_size, hidden_size),
|
||||
self.config.mask_feature_prob,
|
||||
self.config.mask_feature_length,
|
||||
)
|
||||
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
|
||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@@ -756,7 +899,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
@@ -773,7 +916,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
TODO(PVP): Fill out when adding training
|
||||
|
||||
Returns:
|
||||
@@ -831,11 +974,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def freeze_feature_extractor(self):
    """
    Calling this function will disable the gradient computation for the feature extractor so that its parameters
    will not be updated during training.
    """
    feature_extractor = self.wav2vec2.feature_extractor
    feature_extractor._freeze_parameters()
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -848,8 +998,11 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
TODO(PVP): Fill out when adding training
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_length)`, `optional`):
|
||||
Labels for connectionist temporal classification. Note that ``target_length`` has to be smaller or equal to
|
||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -873,9 +1026,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
|
||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
>>> labels = processor(target_transcription, return_tensors="pt").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@@ -893,8 +1055,38 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
# retrieve loss input_lengths from attention_mask
|
||||
attention_mask = (
|
||||
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
||||
)
|
||||
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
|
||||
# assuming that padded tokens are filled with -100
|
||||
# when not being attended to
|
||||
labels_mask = labels >= 0
|
||||
target_lengths = labels_mask.sum(-1)
|
||||
flattened_targets = labels.masked_select(labels_mask)
|
||||
|
||||
log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
|
||||
|
||||
with torch.backends.cudnn.flags(enabled=False):
|
||||
loss = F.ctc_loss(
|
||||
log_probs,
|
||||
flattened_targets,
|
||||
input_lengths,
|
||||
target_lengths,
|
||||
blank=self.config.pad_token_id,
|
||||
reduction=self.config.ctc_loss_reduction,
|
||||
zero_infinity=self.config.ctc_zero_infinity,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
||||
return CausalLMOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
|
||||
@@ -115,6 +115,16 @@ class Wav2Vec2Processor:
|
||||
"""
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
def pad(self, *args, **kwargs):
    """
    When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
    :meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context of
    :meth:`~transformers.Wav2Vec2Processor.as_target_processor`, this method forwards all its arguments to
    Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad` instead. Please refer to the docstrings
    of the above two methods for more information.
    """
    active_processor = self.current_processor
    return active_processor.pad(*args, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Wav2Vec2CTCTokenizer's
|
||||
|
||||
@@ -509,11 +509,18 @@ class Trainer:
|
||||
|
||||
# Build the sampler.
|
||||
if self.args.group_by_length:
|
||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||
if num_processes <= 1:
|
||||
return LengthGroupedSampler(self.train_dataset, self.args.train_batch_size)
|
||||
return LengthGroupedSampler(
|
||||
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
|
||||
)
|
||||
else:
|
||||
return DistributedLengthGroupedSampler(
|
||||
self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index
|
||||
self.train_dataset,
|
||||
self.args.train_batch_size,
|
||||
num_replicas=num_processes,
|
||||
rank=process_index,
|
||||
model_input_name=model_input_name,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -452,16 +452,23 @@ class LengthGroupedSampler(Sampler):
|
||||
keeping a bit of randomness.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    dataset: Dataset,
    batch_size: int,
    lengths: Optional[List[int]] = None,
    model_input_name: Optional[str] = None,
):
    """
    Store the dataset, the batch size and the length of every example.

    Args:
        dataset: The dataset to sample from.
        batch_size: The batch size stored on the sampler.
        lengths: Precomputed length of every example. When :obj:`None`, the lengths are inferred from the
            dataset items, which must then be dictionaries holding the model input under ``model_input_name``.
        model_input_name: Key of the model input inside a dataset item; defaults to ``"input_ids"``.

    Raises:
        ValueError: If ``lengths`` is :obj:`None` and the first dataset item is not a dictionary containing
            the model input key.
    """
    self.dataset = dataset
    self.batch_size = batch_size
    self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
    if lengths is None:
        # Check the *resolved* key. Testing the raw `model_input_name` argument would look for `None` in the
        # item whenever the default is used and therefore always raise.
        if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
            raise ValueError(
                "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                f"'{self.model_input_name}' key."
            )
        lengths = [len(feature[self.model_input_name]) for feature in dataset]
    self.lengths = lengths
|
||||
|
||||
def __len__(self):
|
||||
@@ -487,6 +494,7 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
lengths: Optional[List[int]] = None,
|
||||
model_input_name: Optional[str] = None,
|
||||
):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
@@ -513,14 +521,15 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.seed = seed
|
||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
f"'{self.model_input_name}' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
|
||||
@@ -735,6 +735,8 @@ class ModelTesterMixin:
|
||||
inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
outputs = model(**inputs)
|
||||
|
||||
print(outputs)
|
||||
output = outputs[0]
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import math
|
||||
import unittest
|
||||
|
||||
from tests.test_modeling_common import floats_tensor, random_attention_mask
|
||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||
|
||||
@@ -30,6 +30,7 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor
|
||||
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
|
||||
|
||||
|
||||
class Wav2Vec2ModelTester:
|
||||
@@ -128,9 +129,7 @@ class Wav2Vec2ModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# Not sure how to make this test pass at the moment. Batched input yields
|
||||
# same results as official fairseq implementation, but gives different results
|
||||
# depending on whether batched input is used or not
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
model = Wav2Vec2Model(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -155,6 +154,62 @@ class Wav2Vec2ModelTester:
|
||||
batch_output = batch_outputs[i : i + 1, : output.shape[1]]
|
||||
self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3))
|
||||
|
||||
def check_ctc_loss(self, config, input_values, *args):
    """Check that the summed CTC loss matches the mean CTC loss multiplied by the number of label entries."""
    model = Wav2Vec2ForCTC(config=config)
    model.to(torch_device)

    # evaluation mode, so that dropout is disabled for both forward passes
    model.eval()

    input_values = input_values[:3]
    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)

    # the three examples keep a quarter, a half and all of their samples respectively
    input_lengths = [input_values.shape[-1] // divisor for divisor in (4, 2, 1)]
    max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
    labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size)

    # pad input: zero the samples beyond each length and mask them out
    for row, length in enumerate(input_lengths):
        input_values[row, length:] = 0.0
        attention_mask[row, length:] = 0.0

    losses = {}
    for reduction in ("sum", "mean"):
        model.config.ctc_loss_reduction = reduction
        losses[reduction] = model(input_values, attention_mask=attention_mask, labels=labels).loss

    num_label_entries = labels.shape[0] * labels.shape[1]
    self.parent.assertTrue(abs(num_label_entries * losses["mean"].item() - losses["sum"].item()) < 1e-3)
|
||||
|
||||
def check_training(self, config, input_values, *args):
    """Run one forward and backward pass of ``Wav2Vec2ForCTC`` in training mode and check the loss is not inf."""
    config.ctc_zero_infinity = True
    model = Wav2Vec2ForCTC(config=config)
    model.to(torch_device)
    model.train()

    # freeze feature encoder
    model.freeze_feature_extractor()

    input_values = input_values[:3]

    # the three examples keep a quarter, a half and all of their samples respectively
    input_lengths = [input_values.shape[-1] // divisor for divisor in (4, 2, 1)]
    max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths))
    labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size)

    for row, length in enumerate(input_lengths):
        # pad input
        input_values[row, length:] = 0.0

        if max_length_labels[row] < labels.shape[-1]:
            # it's important that we make sure that target lengths are at least
            # one shorter than logit lengths to prevent -inf
            labels[row, max_length_labels[row] - 1 :] = -100

    loss = model(input_values, labels=labels).loss
    self.parent.assertFalse(torch.isinf(loss).item())

    loss.backward()
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||
@@ -165,6 +220,7 @@ class Wav2Vec2ModelTester:
|
||||
class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
(
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Model,
|
||||
Wav2Vec2ForMaskedLM,
|
||||
)
|
||||
@@ -186,6 +242,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_ctc_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
@@ -205,6 +269,46 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
    """Check that gradients reach the first hidden state and the first attention map after a backward pass."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.output_hidden_states = True
    config.output_attentions = True

    # no need to test all models as different heads yield the same functionality
    model = self.all_model_classes[0](config)
    model.to(torch_device)

    # set layer drop to 0
    model.config.layerdrop = 0.0

    input_values = inputs_dict["input_values"]
    batch_size, num_samples = input_values.shape[0], input_values.shape[1]

    # every example is treated as unpadded, i.e. it spans the full input length
    input_lengths = torch.tensor([num_samples] * batch_size, dtype=torch.long, device=torch_device)
    output_lengths = model._get_feat_extract_output_lengths(input_lengths)

    inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
    inputs_dict["labels"] = ids_tensor((batch_size, output_lengths[0] - 2), self.model_tester.vocab_size)

    outputs = model(**inputs_dict)
    output = outputs[0]

    # Encoder-/Decoder-only models
    first_hidden_states = outputs.hidden_states[0]
    first_attentions = outputs.attentions[0]

    first_hidden_states.retain_grad()
    first_attentions.retain_grad()

    output.flatten()[0].backward(retain_graph=True)

    self.assertIsNotNone(first_hidden_states.grad)
    self.assertIsNotNone(first_attentions.grad)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -213,7 +317,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if "conv.weight" in name:
|
||||
if "conv.weight" in name or "masked_spec_embed" in name:
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||
@@ -233,7 +337,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
|
||||
all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
test_headmasking = False
|
||||
test_torchscript = False
|
||||
@@ -255,6 +359,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
|
||||
|
||||
def test_ctc_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
|
||||
def test_train(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_training(*config_and_inputs)
|
||||
|
||||
# Wav2Vec2 has no inputs_embeds
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
@@ -274,6 +386,46 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_model_common_attributes(self):
|
||||
pass
|
||||
|
||||
def test_retain_grad_hidden_states_attentions(self):
    """Check that gradients reach the first hidden state and the first attention map after a backward pass."""
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    config.output_hidden_states = True
    config.output_attentions = True

    # no need to test all models as different heads yield the same functionality
    model = self.all_model_classes[0](config)
    model.to(torch_device)

    # set layer drop to 0
    model.config.layerdrop = 0.0

    input_values = inputs_dict["input_values"]
    batch_size, num_samples = input_values.shape[0], input_values.shape[1]

    # every example is treated as unpadded, i.e. it spans the full input length
    input_lengths = torch.tensor([num_samples] * batch_size, dtype=torch.long, device=torch_device)
    output_lengths = model._get_feat_extract_output_lengths(input_lengths)

    inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"])
    inputs_dict["labels"] = ids_tensor((batch_size, output_lengths[0] - 2), self.model_tester.vocab_size)

    outputs = model(**inputs_dict)
    output = outputs[0]

    # Encoder-/Decoder-only models
    first_hidden_states = outputs.hidden_states[0]
    first_attentions = outputs.attentions[0]

    first_hidden_states.retain_grad()
    first_attentions.retain_grad()

    output.flatten()[0].backward(retain_graph=True)

    self.assertIsNotNone(first_hidden_states.grad)
    self.assertIsNotNone(first_attentions.grad)
|
||||
|
||||
def test_initialization(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -282,7 +434,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = model_class(config=configs_no_init)
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
if "conv.weight" in name:
|
||||
if "conv.weight" in name or "masked_spec_embed" in name:
|
||||
self.assertTrue(
|
||||
-1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0,
|
||||
msg="Parameter {} of model {} seems not properly initialized".format(name, model_class),
|
||||
@@ -300,6 +452,59 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
@require_torch
class Wav2Vec2UtilsTest(unittest.TestCase):
    """Unit tests for the ``_compute_mask_indices`` masking helper."""

    def test_compute_mask_indices(self):
        """With ``mask_length=1`` the number of masked positions per row is checked exactly."""
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 1

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])

        # mark the second half of every sequence as padding
        attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
        attention_mask[:, -sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
        )

        # only the non-padded half counts, hence half as many masked positions are expected
        self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])

    def test_compute_mask_indices_overlap(self):
        """With ``mask_length=4`` spans may overlap, so only a range of masked positions can be asserted."""
        batch_size = 4
        sequence_length = 60
        mask_prob = 0.5
        mask_length = 4

        mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)

        # because of overlap there is a range of possible masks
        # NOTE: the lower bound uses true division; `mask_prob // mask_length` floor-divides 0.5 by 4 to 0.0,
        # which would make the lower bound always 0 and the check vacuous.
        for batch_sum in mask.sum(axis=-1):
            self.assertIn(
                int(batch_sum),
                list(range(int(mask_prob / mask_length * sequence_length), int(mask_prob * sequence_length))),
            )

        # mark the second half of every sequence as padding
        attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
        attention_mask[:, -sequence_length // 2 :] = 0

        mask = _compute_mask_indices(
            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
        )

        # because of overlap there is a range of possible masks
        for batch_sum in mask.sum(axis=-1):
            self.assertIn(
                int(batch_sum),
                list(
                    range(int(mask_prob / mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
                ),
            )
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@require_datasets
|
||||
|
||||
Reference in New Issue
Block a user