[xtreme-s] Update Minds14 results (#16241)

* update results

* per-language metrics

* Format the per-language metrics
This commit is contained in:
Anton Lozhkov
2022-03-21 22:33:59 +04:00
committed by GitHub
parent 6f1727d83a
commit e226a24f84
2 changed files with 66 additions and 34 deletions

View File

@@ -67,7 +67,7 @@ The corresponding training commands for each dataset are given in the sections b
| Speech Recognition | VoxPopuli | - | - | - | - | | Speech Recognition | VoxPopuli | - | - | - | - |
| Speech Recognition | FLEURS | - | - | - | - | | Speech Recognition | FLEURS | - | - | - | - |
| Speech Translation | CoVoST-2 | - | - | - | - | | Speech Translation | CoVoST-2 | - | - | - | - |
| Speech Classification | Minds-14 | 94.74 F1 / 94.70 Acc. | [here](https://huggingface.co/anton-l/xtreme_s_xlsr_300m_minds14/) | 04:46:40 | 2xA100 | | Speech Classification | Minds-14 | 90.15 F1 / 90.33 Acc. | [here](https://huggingface.co/anton-l/xtreme_s_xlsr_300m_minds14/) | 2:54:21 | 2xA100 |
| Speech Classification | FLEURS | - | - | - | - | | Speech Classification | FLEURS | - | - | - | - |
| Speech Retrieval | FLEURS | - | - | - | - | | Speech Retrieval | FLEURS | - | - | - | - |
@@ -82,7 +82,6 @@ python -m torch.distributed.launch \
--task="mls" \ --task="mls" \
--language="all" \ --language="all" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--eval_split_name="test" \
--output_dir="xtreme_s_xlsr_300m_mls" \ --output_dir="xtreme_s_xlsr_300m_mls" \
--overwrite_output_dir \ --overwrite_output_dir \
--num_train_epochs=100 \ --num_train_epochs=100 \
@@ -158,4 +157,4 @@ python -m torch.distributed.launch \
--push_to_hub --push_to_hub
``` ```
On 2 A100 GPUs, this script should run in ~5 hours and yield a cross-entropy loss of **0.2890** and F1 score of **94.74** On 2 A100 GPUs, this script should run in ~3 hours and yield a cross-entropy loss of **0.4119** and F1 score of **90.15**

View File

@@ -20,6 +20,7 @@ import logging
import os import os
import re import re
import sys import sys
from collections import OrderedDict, defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
@@ -273,6 +274,13 @@ class DataTrainingArguments:
" input audio to a sequence of phoneme sequences." " input audio to a sequence of phoneme sequences."
}, },
) )
per_lang_metrics: bool = field(
default=True,
metadata={
"help": "If `True`, compute the test metrics separately for each language, and average the results. "
"If `False` compute the average test metrics in a single pass for all languages at once."
},
)
@dataclass @dataclass
@@ -470,10 +478,6 @@ def main():
if data_args.max_train_samples is not None: if data_args.max_train_samples is not None:
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
if not is_text_target:
label_list = raw_datasets["train"].features[target_column_name].names
num_labels = len(label_list)
if training_args.do_eval: if training_args.do_eval:
raw_datasets["eval"] = load_dataset( raw_datasets["eval"] = load_dataset(
data_args.dataset_name, data_args.dataset_name,
@@ -498,6 +502,11 @@ def main():
if data_args.max_predict_samples is not None: if data_args.max_predict_samples is not None:
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples)) raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
if not is_text_target:
label_list = next(iter(raw_datasets.values())).features[target_column_name].names
lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
num_labels = len(label_list)
# 2. We remove some special characters from the datasets # 2. We remove some special characters from the datasets
# that make training complicated and do not help in transcribing the speech # that make training complicated and do not help in transcribing the speech
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
@@ -593,6 +602,8 @@ def main():
) )
# adapt config # adapt config
# (speech translation requires pre-configured seq2seq models)
if task_name != "covost2":
config.update( config.update(
{ {
"feat_proj_dropout": model_args.feat_proj_dropout, "feat_proj_dropout": model_args.feat_proj_dropout,
@@ -688,6 +699,9 @@ def main():
batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids
else: else:
batch["labels"] = batch[target_column_name] batch["labels"] = batch[target_column_name]
batch["lang"] = batch["lang_id"]
return batch return batch
with training_args.main_process_first(desc="dataset map preprocessing"): with training_args.main_process_first(desc="dataset map preprocessing"):
@@ -752,6 +766,7 @@ def main():
tokenizer.save_pretrained(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir) config.save_pretrained(training_args.output_dir)
# wait until configs are saved in the main process before loading the processor # wait until configs are saved in the main process before loading the processor
if training_args.local_rank != -1:
torch.distributed.barrier() torch.distributed.barrier()
if is_text_target: if is_text_target:
@@ -816,6 +831,21 @@ def main():
results = {} results = {}
if training_args.do_predict: if training_args.do_predict:
logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***") logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***")
if data_args.per_lang_metrics:
# separate the `test` dataset into language-specific subsets and compute metrics for each of them
metrics = {}
average_metrics = defaultdict(list)
for lang_id in range(len(lang_list)):
lang_name = lang_list[lang_id]
lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id)
lang_metrics = trainer.evaluate(lang_dataset)
for metric_name, value in lang_metrics.items():
average_metrics[metric_name].append(value)
if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]:
metrics[f"{metric_name}_{lang_name}"] = value
for metric_name, value in average_metrics.items():
metrics[metric_name] = np.mean(value)
else:
metrics = trainer.evaluate(vectorized_datasets["predict"]) metrics = trainer.evaluate(vectorized_datasets["predict"])
max_predict_samples = ( max_predict_samples = (
data_args.max_predict_samples data_args.max_predict_samples
@@ -824,6 +854,9 @@ def main():
) )
metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"])) metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
# make sure that the `predict` metrics end up in the log history for the model card
trainer.log(OrderedDict(sorted(metrics.items())))
trainer.log_metrics("predict", metrics) trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics) trainer.save_metrics("predict", metrics)