From a4a88fa09f5fcf38d74ffba41ed932f71935932a Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Wed, 27 Apr 2022 08:34:21 +0200 Subject: [PATCH] [Research] Speed up evaluation for XTREME-S (#16785) * Avoid repeated per-lang filtering * Language groups and logits preprocessing * Style --- .../xtreme-s/run_xtreme_s.py | 54 ++++++++++++++++--- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/examples/research_projects/xtreme-s/run_xtreme_s.py b/examples/research_projects/xtreme-s/run_xtreme_s.py index b6a6e7ae2c..a186d4b7ce 100644 --- a/examples/research_projects/xtreme-s/run_xtreme_s.py +++ b/examples/research_projects/xtreme-s/run_xtreme_s.py @@ -136,6 +136,10 @@ class ModelArguments: metadata={"help": "Length of vector span to mask along the feature axis."}, ) layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."}) + ctc_zero_infinity: bool = field( + default=False, + metadata={"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`."}, + ) ctc_loss_reduction: Optional[str] = field( default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."} ) @@ -166,6 +170,15 @@ class DataTrainingArguments: default="all", metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."}, ) + language_group: str = field( + default=None, + metadata={ + "help": "The language group to select a subset of languages to train on. " + "This option is only used the 'fleurs-asr' task. Should be one of: " + "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', " + "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'." + }, + ) train_split_name: str = field( default="train", metadata={ @@ -441,6 +454,11 @@ def main(): "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'" " for multi-lingual fine-tuning." ) + if data_args.language_group is not None: + if data_args.task != "fleurs-asr": + raise ValueError("--language_group should only be used with --task=fleurs-asr") + if data_args.language != "all": + raise ValueError("--language_group should only be used with --language=all") if data_args.target_column_name is None: target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name] @@ -502,11 +520,23 @@ def main(): if data_args.max_predict_samples is not None: raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples)) + lang_list = next(iter(raw_datasets.values())).features["lang_id"].names if not is_text_target: label_list = next(iter(raw_datasets.values())).features[target_column_name].names - lang_list = next(iter(raw_datasets.values())).features["lang_id"].names num_labels = len(label_list) + num_workers = data_args.preprocessing_num_workers + + lang_group = data_args.language_group + if lang_group is not None: + with training_args.main_process_first(desc="language group filter"): + lang_group_id = next(iter(raw_datasets.values())).features["lang_group_id"].str2int(lang_group) + raw_datasets = raw_datasets.filter( + lambda lang_group: lang_group == lang_group_id, + num_proc=num_workers, + input_columns=["lang_group_id"], + ) + # 2. We remove some special characters from the datasets # that make training complicated and do not help in transcribing the speech # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic @@ -616,6 +646,7 @@ def main(): "mask_feature_length": model_args.mask_feature_length, "gradient_checkpointing": training_args.gradient_checkpointing, "layerdrop": model_args.layerdrop, + "ctc_zero_infinity": model_args.ctc_zero_infinity, "ctc_loss_reduction": model_args.ctc_loss_reduction, "activation_dropout": model_args.activation_dropout, } @@ -675,7 +706,6 @@ def main(): max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate audio_column_name = data_args.audio_column_name - num_workers = data_args.preprocessing_num_workers # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification phoneme_language = data_args.phoneme_language @@ -740,13 +770,13 @@ def main(): logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}") return - def compute_asr_metric(pred): - pred_logits = pred.predictions - pred_ids = np.argmax(pred_logits, axis=-1) + def asr_logits_argmax(logits, labels): + return logits.argmax(dim=-1) + def compute_asr_metric(pred): pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id - pred_str = tokenizer.batch_decode(pred_ids) + pred_str = tokenizer.batch_decode(pred.predictions) # we do not want to group tokens when computing the metrics label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False) @@ -783,6 +813,7 @@ def main(): model=model, data_collator=data_collator, args=training_args, + preprocess_logits_for_metrics=asr_logits_argmax if training_args.predict_with_generate else None, compute_metrics=compute_asr_metric if training_args.predict_with_generate else None, train_dataset=vectorized_datasets["train"] if training_args.do_train else None, eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, @@ -793,6 +824,7 @@ def main(): model=model, data_collator=data_collator, args=training_args, + preprocess_logits_for_metrics=asr_logits_argmax if is_text_target else None, compute_metrics=compute_asr_metric if is_text_target else compute_classification_metric, train_dataset=vectorized_datasets["train"] if training_args.do_train else None, eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, @@ -837,11 +869,17 @@ def main(): average_metrics = defaultdict(list) for lang_id in range(len(lang_list)): lang_name = lang_list[lang_id] - lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id) + with training_args.main_process_first(desc="per-language dataset filter"): + lang_dataset = vectorized_datasets["predict"].filter( + lambda lang: lang == lang_id, + num_proc=num_workers, + input_columns=["lang"], + ) lang_metrics = trainer.evaluate(lang_dataset) + redundant_metrics = ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second", "eval_epoch"] for metric_name, value in lang_metrics.items(): average_metrics[metric_name].append(value) - if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]: + if metric_name not in redundant_metrics: metrics[f"{metric_name}_{lang_name}"] = value for metric_name, value in average_metrics.items(): metrics[metric_name] = np.mean(value)