From 55f49c5f4bf787cce6f04868737291f44a9710ec Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 12 Nov 2021 16:35:57 +0100 Subject: [PATCH] [Wav2Vec2 Example] Improve fine-tuning script (#14373) * improve some stuff * finish * correct last --- .../run_speech_recognition_ctc.py | 32 ++++++++++++++++--- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index 7960c7cc68..638e01e1d6 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -99,9 +99,24 @@ class ModelArguments: metadata={ "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector" "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature" - "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``." + "vectors will be masked along the time axis." }, ) + mask_time_length: Optional[int] = field( + default=10, + metadata={"help": "Length of vector span to mask along the time axis."}, + ) + mask_feature_prob: Optional[float] = field( + default=0.0, + metadata={ + "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector" + "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis." + }, + ) + mask_feature_length: Optional[int] = field( + default=10, + metadata={"help": "Length of vector span to mask along the feature axis."}, + ) layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."}) ctc_loss_reduction: Optional[str] = field( default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} @@ -169,6 +184,10 @@ class DataTrainingArguments: default=None, metadata={"help": "A list of characters to remove from the transcripts."}, ) + eval_metrics: Optional[List[str]] = list_field( + default=["wer"], + metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"}, + ) max_duration_in_seconds: Optional[float] = field( default=20.0, metadata={ @@ -446,6 +465,9 @@ def main(): "hidden_dropout": model_args.hidden_dropout, "final_dropout": model_args.final_dropout, "mask_time_prob": model_args.mask_time_prob, + "mask_time_length": model_args.mask_time_length, + "mask_feature_prob": model_args.mask_feature_prob, + "mask_feature_length": model_args.mask_feature_length, "gradient_checkpointing": training_args.gradient_checkpointing, "layerdrop": model_args.layerdrop, "ctc_loss_reduction": model_args.ctc_loss_reduction, @@ -519,8 +541,8 @@ def main(): # Let's use word error rate (WER) as our evaluation metric, # instantiate a data collator and the trainer - # Define Metric during training - wer_metric = load_metric("wer") + # Define evaluation metrics during training, *i.e.* word error rate, character error rate + eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics} # for large datasets it is advised to run the preprocessing on a # single machine first with ``args.preprocessing_only`` since there will mostly likely @@ -541,9 +563,9 @@ def main(): # we do not want to group tokens when computing the metrics label_str = processor.batch_decode(pred.label_ids, group_tokens=False) - wer = wer_metric.compute(predictions=pred_str, references=label_str) + metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()} - return {"wer": wer} + return metrics # Instantiate custom data collator data_collator = DataCollatorCTCWithPadding(processor=processor)