Add Fine-Tuning for Wav2Vec2 (#10145)
* add encode labels function to tokenizer * start adding finetuning * init dropout * upload * correct convert script * apply changes * fix second typo * make first dummy training run * adapt convert script * push confg for comparison * remove conf * finish training * adapt data collator * add research folder * update according to fairseq feedback * some minor corrections * refactor masking indices a bit * some minor changes * clean tokenizer * finish clean-up * remove previous logic * update run script * correct training * finish changes * finish model * correct bug * fix training a bit more * add some tests * finish gradient checkpointing * finish example * correct gradient checkpointing * improve tokenization method * revert changes in tokenizer * revert general change * adapt fine-tuning * update * save intermediate test * Update README.md * finish finetuning * delete conversion script * Update src/transformers/models/wav2vec2/configuration_wav2vec2.py * Update src/transformers/models/wav2vec2/processing_wav2vec2.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * finish wav2vec2 script * finish wav2vec2 fine-tuning * finalize test * correct test * adapt tests * finish * remove test file Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
3c733f3208
commit
0234de8418
8
examples/research_projects/wav2vec2/README.md
Normal file
8
examples/research_projects/wav2vec2/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
## Fine-tuning Wav2Vec2
|
||||
|
||||
The `run_asr.py` script allows one to fine-tune pretrained Wav2Vec2 models that can be found [here](https://huggingface.co/models?search=facebook/wav2vec2).
|
||||
|
||||
This fine-tuning script can also be run as a Google Colab notebook [TODO: here]( ).
|
||||
|
||||
The script is actively maintained by [Patrick von Platen](https://github.com/patrickvonplaten).
|
||||
Feel free to ask a question on the [Forum](https://discuss.huggingface.co/) or post an issue on [GitHub](https://github.com/huggingface/transformers/issues/new/choose) and add `@patrickvonplaten` as a tag.
|
||||
21
examples/research_projects/wav2vec2/finetune_base_100.sh
Executable file
21
examples/research_projects/wav2vec2/finetune_base_100.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
# Launch run_asr.py to fine-tune facebook/wav2vec2-base on the
# librispeech_asr "clean" configuration, using the "train.100" split.
# Must be started from the directory that contains run_asr.py.

# Collect every command-line option in an array so each one sits on its own
# line without needing a trailing backslash.
train_args=(
  --output_dir="./wav2vec2-base-100h"
  --num_train_epochs="30"
  --per_device_train_batch_size="32"
  --per_device_eval_batch_size="32"
  --evaluation_strategy="steps"
  --save_total_limit="3"
  --save_steps="500"
  --eval_steps="100"
  --logging_steps="50"
  --learning_rate="5e-4"
  --warmup_steps="3000"
  --model_name_or_path="facebook/wav2vec2-base"
  --fp16
  --dataset_name="librispeech_asr"
  --dataset_config_name="clean"
  --train_split_name="train.100"
  --preprocessing_num_workers="32"
  --group_by_length
  --freeze_feature_extractor
)

python run_asr.py "${train_args[@]}"
|
||||
21
examples/research_projects/wav2vec2/finetune_large_lv60_100.sh
Executable file
21
examples/research_projects/wav2vec2/finetune_large_lv60_100.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
# Launch run_asr.py to fine-tune facebook/wav2vec2-large-lv60 on the
# librispeech_asr "clean" configuration, using the "train.100" split.
# Must be started from the directory that contains run_asr.py.

# Collect every command-line option in an array so each one sits on its own
# line without needing a trailing backslash.
train_args=(
  --output_dir="./wav2vec2-large-lv60-100h"
  --num_train_epochs="30"
  --per_device_train_batch_size="16"
  --per_device_eval_batch_size="16"
  --evaluation_strategy="steps"
  --save_total_limit="3"
  --save_steps="500"
  --eval_steps="100"
  --logging_steps="50"
  --learning_rate="5e-4"
  --warmup_steps="3000"
  --model_name_or_path="facebook/wav2vec2-large-lv60"
  --fp16
  --dataset_name="librispeech_asr"
  --dataset_config_name="clean"
  --train_split_name="train.100"
  --preprocessing_num_workers="32"
  --group_by_length
  --freeze_feature_extractor
)

python run_asr.py "${train_args[@]}"
|
||||
4
examples/research_projects/wav2vec2/requirements.txt
Normal file
4
examples/research_projects/wav2vec2/requirements.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
transformers
datasets
torch >= 1.5.0
jiwer
# run_asr.py imports these two directly (`import soundfile as sf`,
# `from packaging import version`), so list them explicitly.
soundfile
packaging
|
||||
281
examples/research_projects/wav2vec2/run_asr.py
Executable file
281
examples/research_projects/wav2vec2/run_asr.py
Executable file
@@ -0,0 +1,281 @@
|
||||
#!/usr/bin/env python3
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
|
||||
import soundfile as sf
|
||||
from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
Wav2Vec2ForCTC,
|
||||
Wav2Vec2Processor,
|
||||
is_apex_available,
|
||||
)
|
||||
|
||||
|
||||
# apex is optional: `amp` is only bound when transformers reports it as available.
if is_apex_available():
    from apex import amp

# Bind the flag unconditionally so that reading it never raises NameError on
# torch versions older than 1.6; `autocast` is only imported on torch >= 1.6.
_is_native_amp_available = False
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
@dataclass
class ModelArguments:
    """
    Command-line options selecting the pretrained checkpoint to fine-tune.

    `model_name_or_path` has no default and therefore must always be supplied.
    """

    # Checkpoint location: a local path or a model identifier.
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models",
        },
    )
    # Optional directory in which downloaded checkpoints are cached.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where do you want to store the pretrained models downloaded from huggingface.co",
        },
    )
    # Freezing is on by default.
    freeze_feature_extractor: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to freeze the feature extractor layers of the model.",
        },
    )
|
||||
|
||||
|
||||
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    # Annotated Optional because the default is None; the previous `str`
    # annotation contradicted that default.
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_split_name: Optional[str] = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
|
||||
|
||||
|
||||
@dataclass
class DataCollatorCTCWithPadding:
    """
    Collator that pads a list of examples into one batch at call time.

    The audio inputs and the labels are padded by two separate calls to
    ``processor.pad`` (the label call runs inside ``processor.as_target_processor()``),
    each with its own length settings.

    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Padding strategy, forwarded unchanged to both ``processor.pad`` calls. Select among:

            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set, forwarded as ``pad_to_multiple_of`` when padding the ``input_values``.
        pad_to_multiple_of_labels (:obj:`int`, `optional`):
            If set, forwarded as ``pad_to_multiple_of`` when padding the ``labels``.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """
        Pad ``features`` into a single batch.

        Each element of ``features`` must provide the keys ``"input_values"`` and ``"labels"``.
        Returns the padded input batch with an added ``"labels"`` entry in which every
        position whose label attention mask is not 1 has been set to -100.
        """
        # Audio and labels have different lengths and need different padding
        # methods, so they are separated before padding.
        audio_examples = []
        for example in features:
            audio_examples.append({"input_values": example["input_values"]})

        target_examples = []
        for example in features:
            target_examples.append({"input_ids": example["labels"]})

        padded = self.processor.pad(
            audio_examples,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        with self.processor.as_target_processor():
            padded_targets = self.processor.pad(
                target_examples,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # Overwrite padded label positions with -100 so that they are ignored by the loss.
        target_ids = padded_targets["input_ids"]
        is_padding = padded_targets.attention_mask.ne(1)
        padded["labels"] = target_ids.masked_fill(is_padding, -100)

        return padded
|
||||
|
||||
|
||||
class CTCTrainer(Trainer):
    """
    :class:`Trainer` subclass whose training step re-reduces the loss according to
    ``config.ctc_loss_reduction`` when more than one GPU is in use.
    """

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.

        Raises:
            ValueError: If several GPUs are used and ``ctc_loss_reduction`` is neither ``"mean"`` nor ``"sum"``.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            # In this branch the config is read through `model.module`
            # (presumably the multi-GPU wrapper around the model - confirm against Trainer).
            # The error message below must use the same path: reading `model.config`
            # on the wrapper would raise AttributeError and hide the real problem.
            reduction = model.module.config.ctc_loss_reduction
            if reduction == "mean":
                loss = loss.mean()
            elif reduction == "sum":
                # Normalise by the number of label entries that are >= 0; the data
                # collator in this file marks padded label positions with -100.
                loss = loss.sum() / (inputs["labels"] >= 0).sum()
            else:
                raise ValueError(f"{reduction} is not valid. Choose one of ['mean', 'sum']")

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        # Backward pass through whichever mixed-precision / engine backend is active.
        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()
|
||||
|
||||
|
||||
def main():
    """
    Fine-tune a pretrained Wav2Vec2 CTC model on a dataset loaded with the `datasets` library.

    Steps: parse the command-line arguments, load model and processor, load the
    training split and the "validation" split, read the audio files, turn audio and
    text into model inputs and labels, then train with :class:`CTCTrainer` while
    reporting the word error rate on the validation split.

    Raises:
        ValueError: If the examples of one preprocessing batch do not all share the same sampling rate.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = Wav2Vec2ForCTC.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    processor = Wav2Vec2Processor.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)

    train_dataset = datasets.load_dataset(
        data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name
    )
    val_dataset = datasets.load_dataset(data_args.dataset_name, data_args.dataset_config_name, split="validation")

    wer_metric = datasets.load_metric("wer")

    def map_to_array(batch):
        # Read the audio file referenced by the example and store samples and sampling rate.
        speech_array, sampling_rate = sf.read(batch["file"])
        batch["speech"] = speech_array
        batch["sampling_rate"] = sampling_rate
        return batch

    train_dataset = train_dataset.map(map_to_array, remove_columns=["file"])
    val_dataset = val_dataset.map(map_to_array, remove_columns=["file"])

    def prepare_dataset(batch):
        # check that all files have the correct sampling rate
        # (explicit raise instead of `assert`, which is stripped under `python -O`;
        # note that only uniformity within the batch is verified here)
        if len(set(batch["sampling_rate"])) != 1:
            raise ValueError(
                f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
            )

        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
        with processor.as_target_processor():
            batch["labels"] = processor(batch["text"]).input_ids
        return batch

    train_dataset = train_dataset.map(
        prepare_dataset,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
    )
    val_dataset = val_dataset.map(
        prepare_dataset,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
    )

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    def compute_metrics(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        # Restore the positions masked with -100 to the tokenizer's own padding id
        # instead of a hard-coded 0, so vocabularies whose pad token is not at
        # index 0 are decoded correctly as well.
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    if model_args.freeze_feature_extractor:
        model.freeze_feature_extractor()

    trainer = CTCTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=processor.feature_extractor,
    )

    trainer.train()


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user