From c2dc89be6246de85fa7085d46a8a746a9ace66cc Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 16 Mar 2022 01:21:31 +0100 Subject: [PATCH] [Xtreme-S] fix some namings (#16183) --- .../{xreme-s => xtreme-s}/README.md | 11 +-- .../{xreme-s => xtreme-s}/requirements.txt | 0 .../{xreme-s => xtreme-s}/run_xtreme_s.py | 99 ++++++++++++++----- 3 files changed, 81 insertions(+), 29 deletions(-) rename examples/research_projects/{xreme-s => xtreme-s}/README.md (96%) rename examples/research_projects/{xreme-s => xtreme-s}/requirements.txt (100%) rename examples/research_projects/{xreme-s => xtreme-s}/run_xtreme_s.py (91%) diff --git a/examples/research_projects/xreme-s/README.md b/examples/research_projects/xtreme-s/README.md similarity index 96% rename from examples/research_projects/xreme-s/README.md rename to examples/research_projects/xtreme-s/README.md index 79e06f1041..d4031a053f 100644 --- a/examples/research_projects/xreme-s/README.md +++ b/examples/research_projects/xtreme-s/README.md @@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/ python -m torch.distributed.launch \ --nproc_per_node=8 \ run_xtreme_s.py \ + --task="mls" \ + --language="all" \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \ - --dataset_name="google/xtreme_s" \ - --dataset_config_name="mls.all" \ --eval_split_name="test" \ --output_dir="xtreme_s_xlsr_300m_mls" \ --overwrite_output_dir \ @@ -94,7 +94,6 @@ python -m torch.distributed.launch \ --learning_rate="3e-4" \ --warmup_steps=3000 \ --evaluation_strategy="steps" \ - --target_column_name="transcription" \ --max_duration_in_seconds=20 \ --save_steps=500 \ --eval_steps=500 \ @@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/ python -m torch.distributed.launch \ --nproc_per_node=2 \ run_xtreme_s.py \ + --task="minds14" \ + --language="all" \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \ - --dataset_name="google/xtreme_s" \ - --dataset_config_name="minds14.all" \ - --eval_split_name="test" \ --output_dir="xtreme_s_xlsr_300m_minds14" \ --overwrite_output_dir \ --num_train_epochs=50 \ @@ -139,7 +137,6 @@ python -m torch.distributed.launch \ --learning_rate="3e-4" \ --warmup_steps=1500 \ --evaluation_strategy="steps" \ - --target_column_name="intent_class" \ --max_duration_in_seconds=30 \ --save_steps=200 \ --eval_steps=200 \ diff --git a/examples/research_projects/xreme-s/requirements.txt b/examples/research_projects/xtreme-s/requirements.txt similarity index 100% rename from examples/research_projects/xreme-s/requirements.txt rename to examples/research_projects/xtreme-s/requirements.txt diff --git a/examples/research_projects/xreme-s/run_xtreme_s.py b/examples/research_projects/xtreme-s/run_xtreme_s.py similarity index 91% rename from examples/research_projects/xreme-s/run_xtreme_s.py rename to examples/research_projects/xtreme-s/run_xtreme_s.py index ee51ece1a0..cc1d261a89 100644 --- a/examples/research_projects/xreme-s/run_xtreme_s.py +++ b/examples/research_projects/xtreme-s/run_xtreme_s.py @@ -62,6 +62,17 @@ def list_field(default=None, metadata=None): return field(default_factory=lambda: default, metadata=metadata) +TASK_TO_TARGET_COLUMN_NAME = { + "fleurs-asr": "transcription", + "fleurs-lang_id": "lang_id", + "mls": "transcription", + "voxpopuli": "transcription", + "covost2": "translation", + "minds14": "intent_class", + "babel": "transcription", +} + + @dataclass class ModelArguments: """ @@ -144,8 +155,16 @@ class DataTrainingArguments: default="xtreme_s", metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"}, ) - dataset_config_name: str = field( - default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + task: str = field( + default=None, + metadata={ + "help": "The task name of the benchmark to use (via the datasets library). Should be on of: " + "'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'." + }, + ) + language: str = field( + default="all", + metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."}, ) train_split_name: str = field( default="train", @@ -160,6 +179,13 @@ class DataTrainingArguments: "Defaults to 'validation'" }, ) + predict_split_name: str = field( + default="test", + metadata={ + "help": "The name of the prediction data set split to use (via the datasets library). " + "Defaults to 'test'" + }, + ) audio_column_name: str = field( default="audio", metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, @@ -192,6 +218,13 @@ class DataTrainingArguments: "value if set." }, ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) chars_to_ignore: Optional[List[str]] = list_field( default=', ? . ! - ; : " “ % ‘ ” �'.split(" "), metadata={"help": "A list of characters to remove from the transcripts."}, @@ -387,22 +420,31 @@ def main(): # 1. First, let's load the dataset raw_datasets = DatasetDict() - if data_args.dataset_config_name is None: + task_name = data_args.task + lang_id = data_args.language + + if task_name is None: raise ValueError( - "Set --dataset_config_name should be set to '.' " - "(e.g. 'mls.pl', 'covost2.en.tr', 'minds14.fr-FR') " - "or '.all' for multi-lingual fine-tuning." + "Set --task should be set to '' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') " + ) + if lang_id is None: + raise ValueError( + "Set --language should be set to the language id of the sub dataset " + "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'" + " for multi-lingual fine-tuning." ) - task_name = data_args.dataset_config_name.split(".")[0] - target_column_name = data_args.target_column_name + target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name] + # here we differentiate between tasks with text as the target and classification tasks is_text_target = target_column_name in ("transcription", "translation") + config_name = ".".join([task_name.split("-")[0], lang_id]) + if training_args.do_train: raw_datasets["train"] = load_dataset( data_args.dataset_name, - data_args.dataset_config_name, + config_name, split=data_args.train_split_name, use_auth_token=data_args.use_auth_token, cache_dir=model_args.cache_dir, @@ -432,7 +474,7 @@ def main(): if training_args.do_eval: raw_datasets["eval"] = load_dataset( data_args.dataset_name, - data_args.dataset_config_name, + config_name, split=data_args.eval_split_name, use_auth_token=data_args.use_auth_token, cache_dir=model_args.cache_dir, @@ -441,6 +483,18 @@ def main(): if data_args.max_eval_samples is not None: raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples)) + if training_args.do_predict: + raw_datasets["predict"] = load_dataset( + data_args.dataset_name, + config_name, + split=data_args.predict_split_name, + use_auth_token=data_args.use_auth_token, + cache_dir=model_args.cache_dir, + ) + + if data_args.max_predict_samples is not None: + raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples)) + # 2. We remove some special characters from the datasets # that make training complicated and do not help in transcribing the speech # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic @@ -757,24 +811,25 @@ def main(): # Evaluation results = {} - if training_args.do_eval: - logger.info("*** Evaluate ***") - metrics = trainer.evaluate() - max_eval_samples = ( - data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"]) + if training_args.do_predict: + logger.info("*** Predicte ***") + metrics = trainer.evaluate(vectorized_datasets["predict"]) + max_predict_samples = ( + data_args.max_predict_samples + if data_args.max_predict_samples is not None + else len(vectorized_datasets["predict"]) ) - metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"])) + metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"])) - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) # Write model card and (optionally) push to hub - config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na" kwargs = { "finetuned_from": model_args.model_name_or_path, - "tasks": "speech-recognition", - "tags": ["automatic-speech-recognition", data_args.dataset_name], - "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}", + "tasks": task_name, + "tags": [task_name, data_args.dataset_name], + "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}", "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}", } if "common_voice" in data_args.dataset_name: