From d35e0c62477d8a99baca3d2ae2e64ec62b64527c Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 16 Mar 2022 17:23:00 +0400 Subject: [PATCH] Minor fixes to XTREME-S (#16193) * Minor fixes * Fix vocab union * Update examples/research_projects/xtreme-s/README.md Co-authored-by: Patrick von Platen * Update README * unused import Co-authored-by: Patrick von Platen --- examples/research_projects/xtreme-s/README.md | 14 +++---- .../xtreme-s/run_xtreme_s.py | 42 ++++++++++--------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/examples/research_projects/xtreme-s/README.md b/examples/research_projects/xtreme-s/README.md index d4031a053f..3c74f634eb 100644 --- a/examples/research_projects/xtreme-s/README.md +++ b/examples/research_projects/xtreme-s/README.md @@ -20,7 +20,7 @@ limitations under the License. The Cross-lingual TRansfer Evaluation of Multilingual Encoders for Speech (XTREME-S) benchmark is a benchmark designed to evaluate speech representations across languages, tasks, domains and data regimes. It covers XX typologically diverse languages and seven downstream tasks grouped in four families: speech recognition, translation, classification and retrieval. -XTREME-S covers speech recognition with BABEL, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (FLoRes) and intent classification (MInds-14) and finally speech retrieval with speech-speech translation data mining (bi-speech retrieval). Each of the tasks covers a subset of the 40 languages included in XTREME-S (shown here with their ISO 639-1 codes): ar, as, ca, cs, cy, da, de, en, en, en, en, es, et, fa, fi, fr, hr, hu, id, it, ja, ka, ko, lo, lt, lv, mn, nl, pl, pt, ro, ru, sk, sl, sv, sw, ta, tl, tr and zh. 
+XTREME-S covers speech recognition with Fleurs, Multilingual LibriSpeech (MLS) and VoxPopuli, speech translation with CoVoST-2, speech classification with LangID (Fleurs) and intent classification (MInds-14) and finally speech(-text) retrieval with Fleurs. Each of the tasks covers a subset of the 102 languages included in XTREME-S (shown here with their ISO 639-3 codes): afr, amh, ara, asm, ast, azj, bel, ben, bos, cat, ceb, zho_simpl, zho_trad, ces, cym, dan, deu, ell, eng, spa, est, fas, ful, fin, tgl, fra, gle, glg, guj, hau, heb, hin, hrv, hun, hye, ind, ibo, isl, ita, jpn, jav, kat, kam, kea, kaz, khm, kan, kor, ckb, kir, ltz, lug, lin, lao, lit, luo, lav, mri, mkd, mal, mon, mar, msa, mlt, mya, nob, npi, nld, nso, nya, oci, orm, ory, pan, pol, pus, por, ron, rus, bul, snd, slk, slv, sna, som, srp, swe, swh, tam, tel, tgk, tha, tur, ukr, umb, urd, uzb, vie, wol, xho, yor and zul. Paper: `` @@ -32,16 +32,14 @@ Based on the [`run_xtreme_s.py`](https://github.com/huggingface/transformers/blo This script can fine-tune any of the pretrained speech models on the [hub](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition) on the [XTREME-S dataset](https://huggingface.co/datasets/google/xtreme_s) tasks. -XTREME-S is made up of 7 different task-specific subsets. Here is how to run the script on each of them: +XTREME-S is made up of 7 different tasks. 
Here is how to run the script on each of them: ```bash export TASK_NAME=mls.all python run_xtreme_s.py \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \ - --dataset_name="google/xtreme_s" \ - --dataset_config_name="${TASK_NAME}" \ - --eval_split_name="validation" \ + --task="${TASK_NAME}" \ --output_dir="xtreme_s_xlsr_${TASK_NAME}" \ --num_train_epochs=100 \ --per_device_train_batch_size=32 \ @@ -49,16 +47,16 @@ python run_xtreme_s.py \ --target_column_name="transcription" \ --save_steps=500 \ --eval_steps=500 \ - --freeze_feature_encoder \ --gradient_checkpointing \ --fp16 \ --group_by_length \ --do_train \ --do_eval \ + --do_predict \ --push_to_hub ``` -where `TASK_NAME` can be one of: `mls.all, voxpopuli, covost2.all, fleurs.all, minds14.all`. +where `TASK_NAME` can be one of: `mls, voxpopuli, covost2, fleurs-asr, fleurs-lang_id, minds14`. We get the following results on the test set of the benchmark's datasets. The corresponding training commands for each dataset are given in the sections below: @@ -109,6 +107,7 @@ python -m torch.distributed.launch \ --group_by_length \ --do_train \ --do_eval \ + --do_predict \ --metric_for_best_model="wer" \ --greater_is_better=False \ --load_best_model_at_end \ @@ -152,6 +151,7 @@ python -m torch.distributed.launch \ --group_by_length \ --do_train \ --do_eval \ + --do_predict \ --metric_for_best_model="f1" \ --greater_is_better=True \ --load_best_model_at_end \ diff --git a/examples/research_projects/xtreme-s/run_xtreme_s.py b/examples/research_projects/xtreme-s/run_xtreme_s.py index cc1d261a89..227380962f 100644 --- a/examples/research_projects/xtreme-s/run_xtreme_s.py +++ b/examples/research_projects/xtreme-s/run_xtreme_s.py @@ -15,7 +15,6 @@ """ Fine-tuning a 🤗 Transformers pretrained speech model on the XTREME-S benchmark tasks""" -import functools import json import logging import os @@ -152,8 +151,8 @@ class DataTrainingArguments: """ dataset_name: str = field( - default="xtreme_s", - metadata={"help": "The name 
of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"}, + default="google/xtreme_s", + metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'google/xtreme_s'"}, ) task: str = field( default=None, @@ -169,21 +168,20 @@ class DataTrainingArguments: train_split_name: str = field( default="train", metadata={ - "help": "The name of the training data set split to use (via the datasets library). " "Defaults to 'train'" + "help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train'" }, ) eval_split_name: str = field( default="validation", metadata={ - "help": "The name of the evaluation data set split to use (via the datasets library). " + "help": "The name of the evaluation dataset split to use (via the datasets library). " "Defaults to 'validation'" }, ) predict_split_name: str = field( default="test", metadata={ - "help": "The name of the prediction data set split to use (via the datasets library). " - "Defaults to 'test'" + "help": "The name of the prediction dataset split to use (via the datasets library). " "Defaults to 'test'" }, ) audio_column_name: str = field( @@ -191,10 +189,10 @@ class DataTrainingArguments: metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, ) target_column_name: str = field( - default="transcription", + default=None, metadata={ "help": "The name of the dataset column containing the target data " - "(transcription/translation/label). Defaults to 'transcription'" + "(transcription/translation/label). If None, the name will be inferred from the task. Defaults to None." 
}, ) overwrite_cache: bool = field( @@ -348,8 +346,10 @@ def create_vocabulary_from_data( ) # take union of all unique characters in each dataset - vocab_set = functools.reduce( - lambda vocab_1, vocab_2: set(vocab_1["vocab"][0]) | set(vocab_2["vocab"][0]), vocabs.values() + vocab_set = ( + (set(vocabs["train"]["vocab"][0]) if "train" in vocabs else set()) + | (set(vocabs["eval"]["vocab"][0]) if "eval" in vocabs else set()) + | (set(vocabs["predict"]["vocab"][0]) if "predict" in vocabs else set()) ) vocab_dict = {v: k for k, v in enumerate(sorted(list(vocab_set)))} @@ -434,7 +434,10 @@ def main(): " for multi-lingual fine-tuning." ) - target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name] + if data_args.target_column_name is None: + target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name] + else: + target_column_name = data_args.target_column_name # here we differentiate between tasks with text as the target and classification tasks is_text_target = target_column_name in ("transcription", "translation") @@ -457,9 +460,9 @@ def main(): f"{', '.join(raw_datasets['train'].column_names)}." ) - if data_args.target_column_name not in raw_datasets["train"].column_names: + if target_column_name not in raw_datasets["train"].column_names: raise ValueError( - f"--target_column_name {data_args.target_column_name} not found in dataset '{data_args.dataset_name}'. " + f"--target_column_name {target_column_name} not found in dataset '{data_args.dataset_name}'. " "Make sure to set `--target_column_name` to the correct text column - one of " f"{', '.join(raw_datasets['train'].column_names)}." 
) @@ -468,7 +471,7 @@ def main(): raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples)) if not is_text_target: - label_list = raw_datasets["train"].features[data_args.target_column_name].names + label_list = raw_datasets["train"].features[target_column_name].names num_labels = len(label_list) if training_args.do_eval: @@ -684,7 +687,7 @@ def main(): if is_text_target: batch["labels"] = tokenizer(batch["target_text"], **additional_kwargs).input_ids else: - batch["labels"] = batch[data_args.target_column_name] + batch["labels"] = batch[target_column_name] return batch with training_args.main_process_first(desc="dataset map preprocessing"): @@ -809,10 +812,10 @@ def main(): trainer.save_metrics("train", metrics) trainer.save_state() - # Evaluation + # Evaluation on the test set results = {} if training_args.do_predict: - logger.info("*** Predicte ***") + logger.info(f"*** Evaluating on the `{data_args.predict_split_name}` set ***") metrics = trainer.evaluate(vectorized_datasets["predict"]) max_predict_samples = ( data_args.max_predict_samples @@ -831,9 +834,8 @@ def main(): "tags": [task_name, data_args.dataset_name], "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}", "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}", + "language": data_args.language, } - if "common_voice" in data_args.dataset_name: - kwargs["language"] = config_name if training_args.push_to_hub: trainer.push_to_hub(**kwargs)