[Wav2Vec2 Example] Improve fine-tuning script (#14373)
* improve some stuff * finish * correct last
This commit is contained in:
committed by
GitHub
parent
21546e59a6
commit
55f49c5f4b
@@ -99,9 +99,24 @@ class ModelArguments:
|
||||
metadata={
|
||||
"help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
||||
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
||||
"vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
|
||||
"vectors will be masked along the time axis."
|
||||
},
|
||||
)
|
||||
mask_time_length: Optional[int] = field(
|
||||
default=10,
|
||||
metadata={"help": "Length of vector span to mask along the time axis."},
|
||||
)
|
||||
mask_feature_prob: Optional[float] = field(
|
||||
default=0.0,
|
||||
metadata={
|
||||
"help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
|
||||
"span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
|
||||
},
|
||||
)
|
||||
mask_feature_length: Optional[int] = field(
|
||||
default=10,
|
||||
metadata={"help": "Length of vector span to mask along the feature axis."},
|
||||
)
|
||||
layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
|
||||
ctc_loss_reduction: Optional[str] = field(
|
||||
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
||||
@@ -169,6 +184,10 @@ class DataTrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||
)
|
||||
eval_metrics: Optional[List[str]] = list_field(
|
||||
default=["wer"],
|
||||
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
|
||||
)
|
||||
max_duration_in_seconds: Optional[float] = field(
|
||||
default=20.0,
|
||||
metadata={
|
||||
@@ -446,6 +465,9 @@ def main():
|
||||
"hidden_dropout": model_args.hidden_dropout,
|
||||
"final_dropout": model_args.final_dropout,
|
||||
"mask_time_prob": model_args.mask_time_prob,
|
||||
"mask_time_length": model_args.mask_time_length,
|
||||
"mask_feature_prob": model_args.mask_feature_prob,
|
||||
"mask_feature_length": model_args.mask_feature_length,
|
||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||
"layerdrop": model_args.layerdrop,
|
||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||
@@ -519,8 +541,8 @@ def main():
|
||||
# Let's use word error rate (WER) as our evaluation metric,
|
||||
# instantiate a data collator and the trainer
|
||||
|
||||
# Define Metric during training
|
||||
wer_metric = load_metric("wer")
|
||||
# Define evaluation metrics during training, *i.e.* word error rate, character error rate
|
||||
eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
|
||||
|
||||
# for large datasets it is advised to run the preprocessing on a
|
||||
# single machine first with ``args.preprocessing_only`` since there will mostly likely
|
||||
@@ -541,9 +563,9 @@ def main():
|
||||
# we do not want to group tokens when computing the metrics
|
||||
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
||||
|
||||
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
||||
metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
|
||||
|
||||
return {"wer": wer}
|
||||
return metrics
|
||||
|
||||
# Instantiate custom data collator
|
||||
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
||||
|
||||
Reference in New Issue
Block a user