From 0234de8418b253e843dda0a18a3e18476de52781 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 1 Mar 2021 12:13:17 +0300 Subject: [PATCH] Add Fine-Tuning for Wav2Vec2 (#10145) * add encode labels function to tokenizer * start adding finetuning * init dropout * upload * correct convert script * apply changes * fix second typo * make first dummy training run * adapt convert script * push confg for comparison * remove conf * finish training * adapt data collator * add research folder * update according to fairseq feedback * some minor corrections * refactor masking indices a bit * some minor changes * clean tokenizer * finish clean-up * remove previous logic * update run script * correct training * finish changes * finish model * correct bug * fix training a bit more * add some tests * finish gradient checkpointing * finish example * correct gradient checkpointing * improve tokenization method * revert changes in tokenizer * revert general change * adapt fine-tuning * update * save intermediate test * Update README.md * finish finetuning * delete conversion script * Update src/transformers/models/wav2vec2/configuration_wav2vec2.py * Update src/transformers/models/wav2vec2/processing_wav2vec2.py Co-authored-by: Lysandre Debut * finish wav2vec2 script * finish wav2vec2 fine-tuning * finalize test * correct test * adapt tests * finish * remove test file Co-authored-by: Lysandre Debut --- docs/source/model_doc/wav2vec2.rst | 2 +- examples/research_projects/wav2vec2/README.md | 8 + .../wav2vec2/finetune_base_100.sh | 21 ++ .../wav2vec2/finetune_large_lv60_100.sh | 21 ++ .../wav2vec2/requirements.txt | 4 + .../research_projects/wav2vec2/run_asr.py | 281 ++++++++++++++++++ .../models/wav2vec2/configuration_wav2vec2.py | 67 ++++- ..._original_pytorch_checkpoint_to_pytorch.py | 81 +++-- .../models/wav2vec2/modeling_wav2vec2.py | 266 ++++++++++++++--- .../models/wav2vec2/processing_wav2vec2.py | 10 + src/transformers/trainer.py | 11 +- 
src/transformers/trainer_pt_utils.py | 23 +- tests/test_modeling_common.py | 2 + tests/test_modeling_wav2vec2.py | 219 +++++++++++++- 14 files changed, 932 insertions(+), 84 deletions(-) create mode 100644 examples/research_projects/wav2vec2/README.md create mode 100755 examples/research_projects/wav2vec2/finetune_base_100.sh create mode 100755 examples/research_projects/wav2vec2/finetune_large_lv60_100.sh create mode 100644 examples/research_projects/wav2vec2/requirements.txt create mode 100755 examples/research_projects/wav2vec2/run_asr.py diff --git a/docs/source/model_doc/wav2vec2.rst b/docs/source/model_doc/wav2vec2.rst index 7f59639581..63b851afb8 100644 --- a/docs/source/model_doc/wav2vec2.rst +++ b/docs/source/model_doc/wav2vec2.rst @@ -62,7 +62,7 @@ Wav2Vec2Processor ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.Wav2Vec2Processor - :members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor + :members: __call__, pad, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor Wav2Vec2Model diff --git a/examples/research_projects/wav2vec2/README.md b/examples/research_projects/wav2vec2/README.md new file mode 100644 index 0000000000..23741d0060 --- /dev/null +++ b/examples/research_projects/wav2vec2/README.md @@ -0,0 +1,8 @@ +## Fine-tuning Wav2Vec2 + +The `run_asr.py` script allows one to finetune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2). + +This finetuning script can also be run as a Google Colab [TODO: here]( ). + +The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten). +Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose) and add `@patrickvonplaten` as a tag. 
diff --git a/examples/research_projects/wav2vec2/finetune_base_100.sh b/examples/research_projects/wav2vec2/finetune_base_100.sh new file mode 100755 index 0000000000..8002dd8123 --- /dev/null +++ b/examples/research_projects/wav2vec2/finetune_base_100.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +python run_asr.py \ +--output_dir="./wav2vec2-base-100h" \ +--num_train_epochs="30" \ +--per_device_train_batch_size="32" \ +--per_device_eval_batch_size="32" \ +--evaluation_strategy="steps" \ +--save_total_limit="3" \ +--save_steps="500" \ +--eval_steps="100" \ +--logging_steps="50" \ +--learning_rate="5e-4" \ +--warmup_steps="3000" \ +--model_name_or_path="facebook/wav2vec2-base" \ +--fp16 \ +--dataset_name="librispeech_asr" \ +--dataset_config_name="clean" \ +--train_split_name="train.100" \ +--preprocessing_num_workers="32" \ +--group_by_length \ +--freeze_feature_extractor diff --git a/examples/research_projects/wav2vec2/finetune_large_lv60_100.sh b/examples/research_projects/wav2vec2/finetune_large_lv60_100.sh new file mode 100755 index 0000000000..3d2423df97 --- /dev/null +++ b/examples/research_projects/wav2vec2/finetune_large_lv60_100.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +python run_asr.py \ +--output_dir="./wav2vec2-large-lv60-100h" \ +--num_train_epochs="30" \ +--per_device_train_batch_size="16" \ +--per_device_eval_batch_size="16" \ +--evaluation_strategy="steps" \ +--save_total_limit="3" \ +--save_steps="500" \ +--eval_steps="100" \ +--logging_steps="50" \ +--learning_rate="5e-4" \ +--warmup_steps="3000" \ +--model_name_or_path="facebook/wav2vec2-large-lv60" \ +--fp16 \ +--dataset_name="librispeech_asr" \ +--dataset_config_name="clean" \ +--train_split_name="train.100" \ +--preprocessing_num_workers="32" \ +--group_by_length \ +--freeze_feature_extractor diff --git a/examples/research_projects/wav2vec2/requirements.txt b/examples/research_projects/wav2vec2/requirements.txt new file mode 100644 index 0000000000..9c360ffdd5 --- /dev/null +++ 
b/examples/research_projects/wav2vec2/requirements.txt @@ -0,0 +1,4 @@ +transformers +datasets +torch >= 1.5.0 +jiwer diff --git a/examples/research_projects/wav2vec2/run_asr.py b/examples/research_projects/wav2vec2/run_asr.py new file mode 100755 index 0000000000..00c64840e2 --- /dev/null +++ b/examples/research_projects/wav2vec2/run_asr.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import datasets +import numpy as np +import torch +import torch.nn as nn +from packaging import version + +import soundfile as sf +from transformers import ( + HfArgumentParser, + Trainer, + TrainingArguments, + Wav2Vec2ForCTC, + Wav2Vec2Processor, + is_apex_available, +) + + +if is_apex_available(): + from apex import amp + + +if version.parse(torch.__version__) >= version.parse("1.6"): + _is_native_amp_available = True + from torch.cuda.amp import autocast + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + freeze_feature_extractor: Optional[bool] = field( + default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."} + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. 
+ """ + + dataset_name: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_split_name: Optional[str] = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + + +@dataclass +class DataCollatorCTCWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor (:class:`~transformers.Wav2Vec2Processor`) + The processor used for proccessing the data. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding index) + among: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + max_length (:obj:`int`, `optional`): + Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). + max_length_labels (:obj:`int`, `optional`): + Maximum length of the ``labels`` returned list and optionally padding length (see above). 
+ pad_to_multiple_of (:obj:`int`, `optional`): + If set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= + 7.5 (Volta). + """ + + processor: Wav2Vec2Processor + padding: Union[bool, str] = True + max_length: Optional[int] = None + max_length_labels: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + pad_to_multiple_of_labels: Optional[int] = None + + def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lenghts and need + # different padding methods + input_features = [{"input_values": feature["input_values"]} for feature in features] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + batch = self.processor.pad( + input_features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors="pt", + ) + with self.processor.as_target_processor(): + labels_batch = self.processor.pad( + label_features, + padding=self.padding, + max_length=self.max_length_labels, + pad_to_multiple_of=self.pad_to_multiple_of_labels, + return_tensors="pt", + ) + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) + + batch["labels"] = labels + + return batch + + +class CTCTrainer(Trainer): + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (:obj:`nn.Module`): + The model to train. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. 
Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + + Return: + :obj:`torch.Tensor`: The tensor with training loss on this batch. + """ + + model.train() + inputs = self._prepare_inputs(inputs) + + if self.use_amp: + with autocast(): + loss = self.compute_loss(model, inputs) + else: + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1: + if model.module.config.ctc_loss_reduction == "mean": + loss = loss.mean() + elif model.module.config.ctc_loss_reduction == "sum": + loss = loss.sum() / (inputs["labels"] >= 0).sum() + else: + raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']") + + if self.args.gradient_accumulation_steps > 1: + loss = loss / self.args.gradient_accumulation_steps + + if self.use_amp: + self.scaler.scale(loss).backward() + elif self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + elif self.deepspeed: + self.deepspeed.backward(loss) + else: + loss.backward() + + return loss.detach() + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. 
+ + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + model = Wav2Vec2ForCTC.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + + train_dataset = datasets.load_dataset( + data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name + ) + val_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_config_name, split="validation") + + wer_metric = datasets.load_metric("wer") + + def map_to_array(batch): + speech_array, sampling_rate = sf.read(batch["file"]) + batch["speech"] = speech_array + batch["sampling_rate"] = sampling_rate + return batch + + train_dataset = train_dataset.map(map_to_array, remove_columns=["file"]) + val_dataset = val_dataset.map(map_to_array, remove_columns=["file"]) + + def prepare_dataset(batch): + # check that all files have the correct sampling rate + assert ( + len(set(batch["sampling_rate"])) == 1 + ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}." 
+ + batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values + with processor.as_target_processor(): + batch["labels"] = processor(batch["text"]).input_ids + return batch + + train_dataset = train_dataset.map( + prepare_dataset, + batch_size=training_args.per_device_train_batch_size, + batched=True, + num_proc=data_args.preprocessing_num_workers, + ) + val_dataset = val_dataset.map( + prepare_dataset, + batch_size=training_args.per_device_train_batch_size, + batched=True, + num_proc=data_args.preprocessing_num_workers, + ) + + data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True) + + def compute_metrics(pred): + pred_logits = pred.predictions + pred_ids = np.argmax(pred_logits, axis=-1) + + pred.label_ids[pred.label_ids == -100] = 0 + + pred_str = processor.batch_decode(pred_ids) + # we do not want to group tokens when computing the metrics + label_str = processor.batch_decode(pred.label_ids, group_tokens=False) + + wer = wer_metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + if model_args.freeze_feature_extractor: + model.freeze_feature_extractor() + + trainer = CTCTrainer( + model=model, + data_collator=data_collator, + args=training_args, + compute_metrics=compute_metrics, + train_dataset=train_dataset, + eval_dataset=val_dataset, + tokenizer=processor.feature_extractor, + ) + + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 4ee10a8960..4fc4987776 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -92,6 +92,33 @@ class Wav2Vec2Config(PretrainedConfig): Whether do apply `stable` layer norm architecture of the Transformer encoder. 
``do_stable_layer_norm is True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is False`` corresponds to applying layer norm after the attention layer. + freeze_feat_extract_train (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to freeze the weights of the feature extractor when training. + apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see + `SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition + <https://arxiv.org/abs/1904.08779>`__. + mask_time_prob (:obj:`float`, `optional`, defaults to 0.05): + Probability of each feature vector along the time axis to be chosen as the start of the vector span to be + masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be + masked along the time axis. This is only relevant if ``apply_spec_augment is True``. + mask_time_length (:obj:`int`, `optional`, defaults to 10): + Length of vector span along the time axis. + mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0): + Probability of each feature vector along the feature axis to be chosen as the start of the vector span to + be masked. Approximately ``mask_feature_prob * hidden_size // mask_feature_length`` feature vectors will be + masked along the feature axis. This is only relevant if ``apply_spec_augment is True``. + mask_feature_length (:obj:`int`, `optional`, defaults to 10): + Length of vector span along the feature axis. + ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`): + Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an + instance of :class:`~transformers.Wav2Vec2ForCTC`. + ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. 
Infinite losses + mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an + instance of :class:`~transformers.Wav2Vec2ForCTC`. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -116,12 +143,15 @@ class Wav2Vec2Config(PretrainedConfig): num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", - hidden_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train - attention_probs_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.1, + final_dropout=0.1, + layerdrop=0.1, initializer_range=0.02, layer_norm_eps=1e-5, feat_extract_norm="group", - feat_extract_dropout=0.0, feat_extract_activation="gelu", conv_dim=(512, 512, 512, 512, 512, 512, 512), conv_stride=(5, 2, 2, 2, 2, 2, 2), @@ -130,6 +160,15 @@ class Wav2Vec2Config(PretrainedConfig): num_conv_pos_embeddings=128, num_conv_pos_embedding_groups=16, do_stable_layer_norm=False, + freeze_feat_extract_train=True, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_feature_prob=0.0, + mask_feature_length=10, + ctc_loss_reduction="sum", + ctc_zero_infinity=False, + gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -138,7 +177,6 @@ class Wav2Vec2Config(PretrainedConfig): super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) self.hidden_size = hidden_size self.feat_extract_norm = feat_extract_norm - self.feat_extract_dropout = feat_extract_dropout self.feat_extract_activation = feat_extract_activation self.conv_dim = list(conv_dim) self.conv_stride = list(conv_stride) @@ -151,12 +189,18 @@ class Wav2Vec2Config(PretrainedConfig): 
self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.num_attention_heads = num_attention_heads - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop self.layer_norm_eps = layer_norm_eps self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm + self.freeze_feat_extract_train = freeze_feat_extract_train + self.gradient_checkpointing = gradient_checkpointing if ( (len(self.conv_stride) != self.num_feat_extract_layers) @@ -169,3 +213,14 @@ class Wav2Vec2Config(PretrainedConfig): f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)" f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`." 
) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity diff --git a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py index cbe74b8c7b..7140d33ea9 100644 --- a/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -20,26 +20,27 @@ import argparse import fairseq import torch -from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging +from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, logging logging.set_verbosity_info() logger = logging.get_logger(__name__) MAPPING = { - "post_extract_proj": "wav2vec2.feature_projection.projection", - "encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv", - "self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj", - "self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj", - "self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj", - "self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj", - "self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm", - "fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense", - "fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense", - "final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm", - "encoder.layer_norm": "wav2vec2.encoder.layer_norm", - "w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm", + "post_extract_proj": "feature_projection.projection", + 
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", } @@ -47,7 +48,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type): for attribute in key.split("."): hf_pointer = getattr(hf_pointer, attribute) - hf_shape = getattr(hf_pointer, weight_type).shape + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + assert ( hf_shape == value.shape ), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}" @@ -59,26 +64,32 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type): hf_pointer.weight_v.data = value elif weight_type == "bias": hf_pointer.bias.data = value - logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.") + else: + hf_pointer.data = value + + logger.info(f"{key + '.' 
+ weight_type if weight_type is not None else ''} was initialized from {full_name}.") -def recursively_load_weights(fairseq_model, hf_model): +def recursively_load_weights(fairseq_model, hf_model, is_finetuned): unused_weights = [] fairseq_dict = fairseq_model.state_dict() + feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor + for name, value in fairseq_dict.items(): is_used = False if "conv_layers" in name: load_conv_layer( name, value, - hf_model.wav2vec2.feature_extractor, + feature_extractor, unused_weights, hf_model.config.feat_extract_norm == "group", ) is_used = True else: for key, mapped_key in MAPPING.items(): + mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key if key in name: is_used = True if "*" in mapped_key: @@ -92,6 +103,8 @@ def recursively_load_weights(fairseq_model, hf_model): weight_type = "weight" elif "bias" in name: weight_type = "bias" + else: + weight_type = None set_recursively(hf_model, mapped_key, value, name, weight_type) continue if not is_used: @@ -137,18 +150,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro @torch.no_grad() -def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None): +def convert_wav2vec2_checkpoint( + checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True +): """ Copy/paste/tweak model's weights to transformers design. 
""" - hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config()) + if config_path is not None: + config = Wav2Vec2Config.from_pretrained(config_path) + else: + config = Wav2Vec2Config() + + if is_finetuned: + hf_wav2vec = Wav2Vec2ForCTC(config) + else: + hf_wav2vec = Wav2Vec2Model(config) + + if is_finetuned: + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], arg_overrides={"data": dict_path} + ) + else: + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path]) - model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( - [checkpoint_path], arg_overrides={"data": dict_path} - ) model = model[0].eval() - recursively_load_weights(model, hf_wav2vec) + recursively_load_weights(model, hf_wav2vec, is_finetuned) hf_wav2vec.save_pretrained(pytorch_dump_folder_path) @@ -158,5 +185,11 @@ if __name__ == "__main__": parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not" + ) args = parser.parse_args() - convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path) + convert_wav2vec2_checkpoint( + args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned + ) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 2322e5cce7..ba548dc3d8 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ 
b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -14,10 +14,10 @@ # limitations under the License. """ PyTorch Wav2Vec2 model. """ - import warnings from typing import Optional, Tuple +import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -44,6 +44,77 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [ ] +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.Tensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_length: size of the mask + min_masks: minimum number of masked spans + + Adapted from `fairseq's data_utils.py + `__. 
+ """ + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + padding_mask = attention_mask.ne(1) if attention_mask is not None else None + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + lengths = np.full(num_mask, mask_length) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + class Wav2Vec2NoLayerNormConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -57,12 +128,10 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Module): stride=config.conv_stride[layer_id], bias=config.conv_bias, ) - self.dropout = nn.Dropout(config.feat_extract_dropout) self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): hidden_states = self.conv(hidden_states) - hidden_states = self.dropout(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states @@ -80,13 +149,11 @@ class Wav2Vec2LayerNormConvLayer(nn.Module): stride=config.conv_stride[layer_id], 
bias=config.conv_bias, ) - self.dropout = nn.Dropout(config.feat_extract_dropout) self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): hidden_states = self.conv(hidden_states) - hidden_states = self.dropout(hidden_states) hidden_states = hidden_states.transpose(-2, -1) hidden_states = self.layer_norm(hidden_states) @@ -109,14 +176,12 @@ class Wav2Vec2GroupNormConvLayer(nn.Module): stride=config.conv_stride[layer_id], bias=config.conv_bias, ) - self.dropout = nn.Dropout(config.feat_extract_dropout) self.activation = ACT2FN[config.feat_extract_activation] self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) def forward(self, hidden_states): hidden_states = self.conv(hidden_states) - hidden_states = self.dropout(hidden_states) hidden_states = self.layer_norm(hidden_states) hidden_states = self.activation(hidden_states) return hidden_states @@ -178,6 +243,10 @@ class Wav2Vec2FeatureExtractor(nn.Module): ) self.conv_layers = nn.ModuleList(conv_layers) + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + def forward(self, input_values): hidden_states = input_values[:, None] for conv_layer in self.conv_layers: @@ -191,7 +260,7 @@ class Wav2Vec2FeatureProjection(nn.Module): super().__init__() self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) - self.dropout = nn.Dropout(config.feat_extract_dropout) + self.dropout = nn.Dropout(config.feat_proj_dropout) def forward(self, hidden_states): hidden_states = self.layer_norm(hidden_states) @@ -346,7 +415,7 @@ class Wav2Vec2Attention(nn.Module): class Wav2Vec2FeedForward(nn.Module): def __init__(self, config): super().__init__() - self.intermediate_dropout = nn.Dropout(config.hidden_dropout_prob) + self.intermediate_dropout = 
nn.Dropout(config.activation_dropout) self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): @@ -355,7 +424,7 @@ class Wav2Vec2FeedForward(nn.Module): self.intermediate_act_fn = config.hidden_act self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.output_dropout = nn.Dropout(config.hidden_dropout_prob) + self.output_dropout = nn.Dropout(config.hidden_dropout) def forward(self, hidden_states): hidden_states = self.intermediate_dense(hidden_states) @@ -381,10 +450,10 @@ class Wav2Vec2EncoderLayer(nn.Module): self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, - dropout=config.hidden_dropout_prob, + dropout=config.attention_dropout, is_decoder=False, ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.feed_forward = Wav2Vec2FeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -401,7 +470,12 @@ class Wav2Vec2EncoderLayer(nn.Module): hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) - return hidden_states, attn_weights + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): @@ -410,10 +484,10 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, - dropout=config.hidden_dropout_prob, + dropout=config.attention_dropout, is_decoder=False, ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.feed_forward = 
Wav2Vec2FeedForward(config) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -428,7 +502,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): hidden_states = attn_residual + hidden_states hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) - return hidden_states, attn_weights + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs class Wav2Vec2Encoder(nn.Module): @@ -437,8 +516,7 @@ class Wav2Vec2Encoder(nn.Module): self.config = config self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - # IMPORTANT: the param for dropout is probs wrong - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) def forward( @@ -471,12 +549,32 @@ class Wav2Vec2Encoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - hidden_states, attn_weights = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + if self.training and (dropout_probability < self.config.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + 
) + hidden_states = layer_outputs[0] if output_attentions: - all_self_attentions = all_self_attentions + (attn_weights,) + all_self_attentions = all_self_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -496,8 +594,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): self.config = config self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - # IMPORTANT: the param for dropout is probs wrong - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList( [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) @@ -531,12 +628,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - hidden_states, attn_weights = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + if self.training and (dropout_probability < self.config.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + if getattr(self.config, "gradient_checkpointing", False) and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] if output_attentions: - all_self_attentions = all_self_attentions + (attn_weights,) + all_self_attentions = all_self_attentions + (layer_outputs[1],) 
hidden_states = self.layer_norm(hidden_states) @@ -584,7 +701,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): def _conv_out_length(input_length, kernel_size, stride): # 1D convolutional layer output length formula taken # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html - return torch.floor((input_length - kernel_size) / stride + 1) + return (input_length - kernel_size) // stride + 1 for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): input_lengths = _conv_out_length(input_lengths, kernel_size, stride) @@ -659,6 +776,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): self.feature_extractor = Wav2Vec2FeatureExtractor(config) self.feature_projection = Wav2Vec2FeatureProjection(config) + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + if config.do_stable_layer_norm: self.encoder = Wav2Vec2EncoderStableLayerNorm(config) else: @@ -726,6 +845,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): hidden_states = self.feature_projection(hidden_states) + if self.config.apply_spec_augment and self.training: + batch_size, sequence_length, hidden_size = hidden_states.size() + + # apply SpecAugment along time axis + if self.config.mask_time_prob > 0: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + self.config.mask_time_prob, + self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=2, + ) + hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) + + # apply SpecAugment along feature axis + if self.config.mask_feature_prob > 0: + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + self.config.mask_feature_prob, + self.config.mask_feature_length, + ) + mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) + hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 + encoder_outputs = self.encoder( hidden_states, 
attention_mask=attention_mask, @@ -756,7 +899,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): ) self.wav2vec2 = Wav2Vec2Model(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.final_dropout) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) self.init_weights() @@ -773,7 +916,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel): labels=None, ): r""" - labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): TODO(PVP): Fill out when adding training Returns: @@ -831,11 +974,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): super().__init__(config) self.wav2vec2 = Wav2Vec2Model(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.dropout = nn.Dropout(config.final_dropout) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) self.init_weights() + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature extractor so that its parameters + will not be updated during training. + """ + self.wav2vec2.feature_extractor._freeze_parameters() + @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -848,8 +998,11 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): labels=None, ): r""" - labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - TODO(PVP): Fill out when adding training + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_length)`, `optional`): + Labels for connectionist temporal classification. Note that ``target_length`` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size - + 1]``.
All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., + config.vocab_size - 1]``. Returns: @@ -873,9 +1026,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1 >>> logits = model(input_values).logits - >>> predicted_ids = torch.argmax(logits, dim=-1) + >>> transcription = processor.decode(predicted_ids[0]) + + >>> # compute loss + >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST" + + >>> # wrap processor as target processor to encode labels + >>> with processor.as_target_processor(): + >>> labels = processor(target_transcription, return_tensors="pt").input_ids + + >>> loss = model(input_values, labels=labels).loss """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -893,8 +1055,38 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel): logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = F.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + if not return_dict: output = (logits,) + outputs[1:] - return output + return ((loss,) + output) if loss is not None else output - return CausalLMOutput(logits=logits,
hidden_states=outputs.hidden_states, attentions=outputs.attentions) + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) diff --git a/src/transformers/models/wav2vec2/processing_wav2vec2.py b/src/transformers/models/wav2vec2/processing_wav2vec2.py index 3f1bd6b4b6..71202a2ff0 100644 --- a/src/transformers/models/wav2vec2/processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/processing_wav2vec2.py @@ -115,6 +115,16 @@ class Wav2Vec2Processor: """ return self.current_processor(*args, **kwargs) + def pad(self, *args, **kwargs): + """ + When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's + :meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context + :meth:`~transformers.Wav2Vec2Processor.as_target_processor` this method forwards all its arguments to + Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the + above two methods for more information. + """ + return self.current_processor.pad(*args, **kwargs) + def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to Wav2Vec2CTCTokenizer's diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 009c9bff10..504b852cfe 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -509,11 +509,18 @@ class Trainer: # Build the sampler. 
if self.args.group_by_length: + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None if num_processes <= 1: - return LengthGroupedSampler(self.train_dataset, self.args.train_batch_size) + return LengthGroupedSampler( + self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name + ) else: return DistributedLengthGroupedSampler( - self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index + self.train_dataset, + self.args.train_batch_size, + num_replicas=num_processes, + rank=process_index, + model_input_name=model_input_name, ) else: diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 9b76a23241..c94c58b8a7 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -452,16 +452,23 @@ class LengthGroupedSampler(Sampler): keeping a bit of randomness. """ - def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None): + def __init__( + self, + dataset: Dataset, + batch_size: int, + lengths: Optional[List[int]] = None, + model_input_name: Optional[str] = None, + ): self.dataset = dataset self.batch_size = batch_size + self.model_input_name = model_input_name if model_input_name is not None else "input_ids" if lengths is None: - if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]: + if not isinstance(dataset[0], dict) or model_input_name not in dataset[0]: raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " - "'input_ids' key." + f"'{self.model_input_name}' key." 
) - lengths = [len(feature["input_ids"]) for feature in dataset] + lengths = [len(feature[self.model_input_name]) for feature in dataset] self.lengths = lengths def __len__(self): @@ -487,6 +494,7 @@ class DistributedLengthGroupedSampler(DistributedSampler): seed: int = 0, drop_last: bool = False, lengths: Optional[List[int]] = None, + model_input_name: Optional[str] = None, ): if num_replicas is None: if not dist.is_available(): @@ -513,14 +521,15 @@ class DistributedLengthGroupedSampler(DistributedSampler): self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) self.total_size = self.num_samples * self.num_replicas self.seed = seed + self.model_input_name = model_input_name if model_input_name is not None else "input_ids" if lengths is None: - if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]: + if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]: raise ValueError( "Can only automatically infer lengths for datasets whose items are dictionaries with an " - "'input_ids' key." + f"'{self.model_input_name}' key." 
) - lengths = [len(feature["input_ids"]) for feature in dataset] + lengths = [len(feature[self.model_input_name]) for feature in dataset] self.lengths = lengths def __iter__(self) -> Iterator: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2fe722a2bb..afded0b3fe 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -735,6 +735,8 @@ class ModelTesterMixin: inputs = self._prepare_for_class(inputs_dict, model_class) outputs = model(**inputs) + + print(outputs) output = outputs[0] if config.is_encoder_decoder: diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index cbf0583a3c..75b8795f78 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -18,7 +18,7 @@ import math import unittest -from tests.test_modeling_common import floats_tensor, random_attention_mask +from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from transformers import is_torch_available from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device @@ -30,6 +30,7 @@ if is_torch_available(): import torch from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Processor + from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices class Wav2Vec2ModelTester: @@ -128,9 +129,7 @@ class Wav2Vec2ModelTester: ) def create_and_check_batch_inference(self, config, input_values, *args): - # Not sure how to make this test pass at the moment. 
Batched input yields - # same results as official fairseq implementation, but gives different results - # depending on whether batched input is used or not + # test does not pass for models making use of `group_norm` # check: https://github.com/pytorch/fairseq/issues/3227 model = Wav2Vec2Model(config=config) model.to(torch_device) @@ -155,6 +154,62 @@ class Wav2Vec2ModelTester: batch_output = batch_outputs[i : i + 1, : output.shape[1]] self.parent.assertTrue(torch.allclose(output, batch_output, atol=1e-3)) + def check_ctc_loss(self, config, input_values, *args): + model = Wav2Vec2ForCTC(config=config) + model.to(torch_device) + + # make sure that dropout is disabled + model.eval() + + input_values = input_values[:3] + attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool) + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], min(max_length_labels) - 1), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + attention_mask[i, input_lengths[i] :] = 0.0 + + model.config.ctc_loss_reduction = "sum" + sum_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss + + model.config.ctc_loss_reduction = "mean" + mean_loss = model(input_values, attention_mask=attention_mask, labels=labels).loss + + self.parent.assertTrue(abs(labels.shape[0] * labels.shape[1] * mean_loss.item() - sum_loss.item()) < 1e-3) + + def check_training(self, config, input_values, *args): + config.ctc_zero_infinity = True + model = Wav2Vec2ForCTC(config=config) + model.to(torch_device) + model.train() + + # freeze feature encoder + model.freeze_feature_extractor() + + input_values = input_values[:3] + + input_lengths = [input_values.shape[-1] // i for i in [4, 2, 1]] + max_length_labels = 
model._get_feat_extract_output_lengths(torch.tensor(input_lengths)) + labels = ids_tensor((input_values.shape[0], max(max_length_labels) - 2), model.config.vocab_size) + + # pad input + for i in range(len(input_lengths)): + input_values[i, input_lengths[i] :] = 0.0 + + if max_length_labels[i] < labels.shape[-1]: + # it's important that we make sure that target lengths are at least + # one shorter than logit lengths to prevent -inf + labels[i, max_length_labels[i] - 1 :] = -100 + + loss = model(input_values, labels=labels).loss + self.parent.assertFalse(torch.isinf(loss).item()) + + loss.backward() + def prepare_config_and_inputs_for_common(self): config, input_values, attention_mask = self.prepare_config_and_inputs() inputs_dict = {"input_values": input_values, "attention_mask": attention_mask} @@ -165,6 +220,7 @@ class Wav2Vec2ModelTester: class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( ( + Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM, ) @@ -186,6 +242,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_loss(*config_and_inputs) + + def test_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_training(*config_and_inputs) + # Wav2Vec2 has no inputs_embeds def test_inputs_embeds(self): pass @@ -205,6 +269,46 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): def test_model_common_attributes(self): pass + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality
+ model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + # set layer drop to 0 + model.config.layerdrop = 0.0 + + input_values = inputs_dict["input_values"] + + input_lengths = torch.tensor( + [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device + ) + output_lengths = model._get_feat_extract_output_lengths(input_lengths) + + labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size) + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"]) + inputs_dict["labels"] = labels + + outputs = model(**inputs_dict) + + output = outputs[0] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + attentions = outputs.attentions[0] + + hidden_states.retain_grad() + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + self.assertIsNotNone(attentions.grad) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -213,7 +317,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): model = model_class(config=configs_no_init) for name, param in model.named_parameters(): if param.requires_grad: - if "conv.weight" in name: + if "conv.weight" in name or "masked_spec_embed" in name: self.assertTrue( -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, msg="Parameter {} of model {} seems not properly initialized".format(name, model_class), @@ -233,7 +337,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): @require_torch class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else () + all_model_classes = (Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else () test_pruning = False test_headmasking = False 
test_torchscript = False @@ -255,6 +359,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_batch_inference(*config_and_inputs) + def test_ctc_loss_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_ctc_loss(*config_and_inputs) + + def test_train(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_training(*config_and_inputs) + # Wav2Vec2 has no inputs_embeds def test_inputs_embeds(self): pass @@ -274,6 +386,46 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): def test_model_common_attributes(self): pass + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + # set layer drop to 0 + model.config.layerdrop = 0.0 + + input_values = inputs_dict["input_values"] + + input_lengths = torch.tensor( + [input_values.shape[1] for _ in range(input_values.shape[0])], dtype=torch.long, device=torch_device + ) + output_lengths = model._get_feat_extract_output_lengths(input_lengths) + + labels = ids_tensor((input_values.shape[0], output_lengths[0] - 2), self.model_tester.vocab_size) + inputs_dict["attention_mask"] = torch.ones_like(inputs_dict["attention_mask"]) + inputs_dict["labels"] = labels + + outputs = model(**inputs_dict) + + output = outputs[0] + + # Encoder-/Decoder-only models + hidden_states = outputs.hidden_states[0] + attentions = outputs.attentions[0] + + hidden_states.retain_grad() + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + 
self.assertIsNotNone(hidden_states.grad) + self.assertIsNotNone(attentions.grad) + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -282,7 +434,7 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): model = model_class(config=configs_no_init) for name, param in model.named_parameters(): if param.requires_grad: - if "conv.weight" in name: + if "conv.weight" in name or "masked_spec_embed" in name: self.assertTrue( -1.0 <= ((param.data.mean() * 1e9).round() / 1e9).item() <= 1.0, msg="Parameter {} of model {} seems not properly initialized".format(name, model_class), @@ -300,6 +452,59 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) +@require_torch +class Wav2Vec2UtilsTest(unittest.TestCase): + def test_compute_mask_indices(self): + batch_size = 4 + sequence_length = 60 + mask_prob = 0.5 + mask_length = 1 + + mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) + + self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)]) + + attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long) + attention_mask[:, -sequence_length // 2 :] = 0 + + mask = _compute_mask_indices( + (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask + ) + + self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)]) + + def test_compute_mask_indices_overlap(self): + batch_size = 4 + sequence_length = 60 + mask_prob = 0.5 + mask_length = 4 + + mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) + + # because of overlap there is a range of possible masks + for batch_sum in mask.sum(axis=-1): + self.assertIn( + int(batch_sum), + list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))), + ) + + 
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long) + attention_mask[:, -sequence_length // 2 :] = 0 + + mask = _compute_mask_indices( + (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask + ) + + # because of overlap there is a range of possible masks + for batch_sum in mask.sum(axis=-1): + self.assertIn( + int(batch_sum), + list( + range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2)) + ), + ) + + @require_torch @slow @require_datasets