[xtreme-s] Update Minds14 results (#16241)
* update results * per-language metrics * Format the per-language metrics
This commit is contained in:
@@ -67,7 +67,7 @@ The corresponding training commands for each dataset are given in the sections b
|
||||
| Speech Recognition | VoxPopuli | - | - | - | - |
|
||||
| Speech Recognition | FLEURS | - | - | - | - |
|
||||
| Speech Translation | CoVoST-2 | - | - | - | - |
|
||||
| Speech Classification | Minds-14 | 94.74 F1 / 94.70 Acc. | [here](https://huggingface.co/anton-l/xtreme_s_xlsr_300m_minds14/) | 04:46:40 | 2xA100 |
|
||||
| Speech Classification | Minds-14 | 90.15 F1 / 90.33 Acc. | [here](https://huggingface.co/anton-l/xtreme_s_xlsr_300m_minds14/) | 2:54:21 | 2xA100 |
|
||||
| Speech Classification | FLEURS | - | - | - | - |
|
||||
| Speech Retrieval | FLEURS | - | - | - | - |
|
||||
|
||||
@@ -82,7 +82,6 @@ python -m torch.distributed.launch \
|
||||
--task="mls" \
|
||||
--language="all" \
|
||||
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
|
||||
--eval_split_name="test" \
|
||||
--output_dir="xtreme_s_xlsr_300m_mls" \
|
||||
--overwrite_output_dir \
|
||||
--num_train_epochs=100 \
|
||||
@@ -158,4 +157,4 @@ python -m torch.distributed.launch \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
On 2 A100 GPUs, this script should run in ~5 hours and yield a cross-entropy loss of **0.2890** and F1 score of **94.74**
|
||||
On 2 A100 GPUs, this script should run in ~5 hours and yield a cross-entropy loss of **0.4119** and F1 score of **90.15**
|
||||
|
||||
@@ -20,6 +20,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import OrderedDict, defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -273,6 +274,13 @@ class DataTrainingArguments:
|
||||
" input audio to a sequence of phoneme sequences."
|
||||
},
|
||||
)
|
||||
per_lang_metrics: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "If `True`, compute the test metrics separately for each language, and average the results. "
|
||||
"If `False` compute the average test metrics in a single pass for all languages at once."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -470,10 +478,6 @@ def main():
|
||||
if data_args.max_train_samples is not None:
|
||||
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
||||
|
||||
if not is_text_target:
|
||||
label_list = raw_datasets["train"].features[target_column_name].names
|
||||
num_labels = len(label_list)
|
||||
|
||||
if training_args.do_eval:
|
||||
raw_datasets["eval"] = load_dataset(
|
||||
data_args.dataset_name,
|
||||
@@ -498,6 +502,11 @@ def main():
|
||||
if data_args.max_predict_samples is not None:
|
||||
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
|
||||
|
||||
if not is_text_target:
|
||||
label_list = next(iter(raw_datasets.values())).features[target_column_name].names
|
||||
lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
|
||||
num_labels = len(label_list)
|
||||
|
||||
# 2. We remove some special characters from the datasets
|
||||
# that make training complicated and do not help in transcribing the speech
|
||||
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
||||
@@ -593,31 +602,33 @@ def main():
|
||||
)
|
||||
|
||||
# adapt config
|
||||
config.update(
|
||||
{
|
||||
"feat_proj_dropout": model_args.feat_proj_dropout,
|
||||
"attention_dropout": model_args.attention_dropout,
|
||||
"hidden_dropout": model_args.hidden_dropout,
|
||||
"final_dropout": model_args.final_dropout,
|
||||
"mask_time_prob": model_args.mask_time_prob,
|
||||
"mask_time_length": model_args.mask_time_length,
|
||||
"mask_feature_prob": model_args.mask_feature_prob,
|
||||
"mask_feature_length": model_args.mask_feature_length,
|
||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||
"layerdrop": model_args.layerdrop,
|
||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||
"activation_dropout": model_args.activation_dropout,
|
||||
}
|
||||
)
|
||||
if training_args.do_train:
|
||||
if is_text_target:
|
||||
config.pad_token_id = tokenizer.pad_token_id
|
||||
config.vocab_size = len(tokenizer)
|
||||
else:
|
||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||
config.label2id = label_to_id
|
||||
config.id2label = {id: label for label, id in label_to_id.items()}
|
||||
config.num_labels = num_labels
|
||||
# (speech translation requires pre-configured seq2seq models)
|
||||
if task_name != "covost2":
|
||||
config.update(
|
||||
{
|
||||
"feat_proj_dropout": model_args.feat_proj_dropout,
|
||||
"attention_dropout": model_args.attention_dropout,
|
||||
"hidden_dropout": model_args.hidden_dropout,
|
||||
"final_dropout": model_args.final_dropout,
|
||||
"mask_time_prob": model_args.mask_time_prob,
|
||||
"mask_time_length": model_args.mask_time_length,
|
||||
"mask_feature_prob": model_args.mask_feature_prob,
|
||||
"mask_feature_length": model_args.mask_feature_length,
|
||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||
"layerdrop": model_args.layerdrop,
|
||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||
"activation_dropout": model_args.activation_dropout,
|
||||
}
|
||||
)
|
||||
if training_args.do_train:
|
||||
if is_text_target:
|
||||
config.pad_token_id = tokenizer.pad_token_id
|
||||
config.vocab_size = len(tokenizer)
|
||||
else:
|
||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||
config.label2id = label_to_id
|
||||
config.id2label = {id: label for label, id in label_to_id.items()}
|
||||
config.num_labels = num_labels
|
||||
|
||||
# create model
|
||||
if target_column_name == "transcription":
|
||||
@@ -688,6 +699,9 @@ def main():
|
||||
batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
|
||||
else:
|
||||
batch["labels"] = batch[target_column_name]
|
||||
|
||||
batch["lang"] = batch["lang_id"]
|
||||
|
||||
return batch
|
||||
|
||||
with training_args.main_process_first(desc="dataset map preprocessing"):
|
||||
@@ -752,7 +766,8 @@ def main():
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
config.save_pretrained(training_args.output_dir)
|
||||
# wait until configs are saved in the main process before loading the processor
|
||||
torch.distributed.barrier()
|
||||
if training_args.local_rank != -1:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if is_text_target:
|
||||
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
||||
@@ -816,7 +831,22 @@ def main():
|
||||
results = {}
|
||||
if training_args.do_predict:
|
||||
logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***")
|
||||
metrics = trainer.evaluate(vectorized_datasets["predict"])
|
||||
if data_args.per_lang_metrics:
|
||||
# separate the `test` dataset into language-specific subsets and compute metrics for each of them
|
||||
metrics = {}
|
||||
average_metrics = defaultdict(list)
|
||||
for lang_id in range(len(lang_list)):
|
||||
lang_name = lang_list[lang_id]
|
||||
lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id)
|
||||
lang_metrics = trainer.evaluate(lang_dataset)
|
||||
for metric_name, value in lang_metrics.items():
|
||||
average_metrics[metric_name].append(value)
|
||||
if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]:
|
||||
metrics[f"{metric_name}_{lang_name}"] = value
|
||||
for metric_name, value in average_metrics.items():
|
||||
metrics[metric_name] = np.mean(value)
|
||||
else:
|
||||
metrics = trainer.evaluate(vectorized_datasets["predict"])
|
||||
max_predict_samples = (
|
||||
data_args.max_predict_samples
|
||||
if data_args.max_predict_samples is not None
|
||||
@@ -824,6 +854,9 @@ def main():
|
||||
)
|
||||
metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
|
||||
|
||||
# make sure that the `predict` metrics end up in the log history for the model card
|
||||
trainer.log(OrderedDict(sorted(metrics.items())))
|
||||
|
||||
trainer.log_metrics("predict", metrics)
|
||||
trainer.save_metrics("predict", metrics)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user