[Wav2Vec2 Example] Improve fine-tuning script (#14373)
* improve some stuff * finish * correct last
This commit is contained in:
committed by
GitHub
parent
21546e59a6
commit
55f49c5f4b
@@ -99,9 +99,24 @@ class ModelArguments:
|
|||||||
metadata={
|
metadata={
|
||||||
"help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
"help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
|
||||||
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
|
||||||
"vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
|
"vectors will be masked along the time axis."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
mask_time_length: Optional[int] = field(
|
||||||
|
default=10,
|
||||||
|
metadata={"help": "Length of vector span to mask along the time axis."},
|
||||||
|
)
|
||||||
|
mask_feature_prob: Optional[float] = field(
|
||||||
|
default=0.0,
|
||||||
|
metadata={
|
||||||
|
"help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
|
||||||
|
"span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mask_feature_length: Optional[int] = field(
|
||||||
|
default=10,
|
||||||
|
metadata={"help": "Length of vector span to mask along the feature axis."},
|
||||||
|
)
|
||||||
layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
|
layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
|
||||||
ctc_loss_reduction: Optional[str] = field(
|
ctc_loss_reduction: Optional[str] = field(
|
||||||
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
||||||
@@ -169,6 +184,10 @@ class DataTrainingArguments:
|
|||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "A list of characters to remove from the transcripts."},
|
metadata={"help": "A list of characters to remove from the transcripts."},
|
||||||
)
|
)
|
||||||
|
eval_metrics: Optional[List[str]] = list_field(
|
||||||
|
default=["wer"],
|
||||||
|
metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
|
||||||
|
)
|
||||||
max_duration_in_seconds: Optional[float] = field(
|
max_duration_in_seconds: Optional[float] = field(
|
||||||
default=20.0,
|
default=20.0,
|
||||||
metadata={
|
metadata={
|
||||||
@@ -446,6 +465,9 @@ def main():
|
|||||||
"hidden_dropout": model_args.hidden_dropout,
|
"hidden_dropout": model_args.hidden_dropout,
|
||||||
"final_dropout": model_args.final_dropout,
|
"final_dropout": model_args.final_dropout,
|
||||||
"mask_time_prob": model_args.mask_time_prob,
|
"mask_time_prob": model_args.mask_time_prob,
|
||||||
|
"mask_time_length": model_args.mask_time_length,
|
||||||
|
"mask_feature_prob": model_args.mask_feature_prob,
|
||||||
|
"mask_feature_length": model_args.mask_feature_length,
|
||||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||||
"layerdrop": model_args.layerdrop,
|
"layerdrop": model_args.layerdrop,
|
||||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||||
@@ -519,8 +541,8 @@ def main():
|
|||||||
# Let's use word error rate (WER) as our evaluation metric,
|
# Let's use word error rate (WER) as our evaluation metric,
|
||||||
# instantiate a data collator and the trainer
|
# instantiate a data collator and the trainer
|
||||||
|
|
||||||
# Define Metric during training
|
# Define evaluation metrics during training, *e.g.* word error rate, character error rate
|
||||||
wer_metric = load_metric("wer")
|
eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}
|
||||||
|
|
||||||
# for large datasets it is advised to run the preprocessing on a
|
# for large datasets it is advised to run the preprocessing on a
|
||||||
# single machine first with ``args.preprocessing_only`` since there will most likely
|
# single machine first with ``args.preprocessing_only`` since there will most likely
|
||||||
@@ -541,9 +563,9 @@ def main():
|
|||||||
# we do not want to group tokens when computing the metrics
|
# we do not want to group tokens when computing the metrics
|
||||||
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
|
||||||
|
|
||||||
wer = wer_metric.compute(predictions=pred_str, references=label_str)
|
metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
|
||||||
|
|
||||||
return {"wer": wer}
|
return metrics
|
||||||
|
|
||||||
# Instantiate custom data collator
|
# Instantiate custom data collator
|
||||||
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
data_collator = DataCollatorCTCWithPadding(processor=processor)
|
||||||
|
|||||||
Reference in New Issue
Block a user