[Research] Speed up evaluation for XTREME-S (#16785)
* Avoid repeated per-lang filtering
* Language groups and logits preprocessing
* Style
This commit is contained in:
@@ -136,6 +136,10 @@ class ModelArguments:
|
|||||||
metadata={"help": "Length of vector span to mask along the feature axis."},
|
metadata={"help": "Length of vector span to mask along the feature axis."},
|
||||||
)
|
)
|
||||||
layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
|
layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
|
||||||
|
ctc_zero_infinity: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`."},
|
||||||
|
)
|
||||||
ctc_loss_reduction: Optional[str] = field(
|
ctc_loss_reduction: Optional[str] = field(
|
||||||
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
|
||||||
)
|
)
|
||||||
@@ -166,6 +170,15 @@ class DataTrainingArguments:
|
|||||||
default="all",
|
default="all",
|
||||||
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
|
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
|
||||||
)
|
)
|
||||||
|
language_group: str = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The language group to select a subset of languages to train on. "
|
||||||
|
"This option is only used the 'fleurs-asr' task. Should be one of: "
|
||||||
|
"'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
|
||||||
|
"'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
|
||||||
|
},
|
||||||
|
)
|
||||||
train_split_name: str = field(
|
train_split_name: str = field(
|
||||||
default="train",
|
default="train",
|
||||||
metadata={
|
metadata={
|
||||||
@@ -441,6 +454,11 @@ def main():
|
|||||||
"config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
|
"config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
|
||||||
" for multi-lingual fine-tuning."
|
" for multi-lingual fine-tuning."
|
||||||
)
|
)
|
||||||
|
if data_args.language_group is not None:
|
||||||
|
if data_args.task != "fleurs-asr":
|
||||||
|
raise ValueError("--language_group should only be used with --task=fleurs-asr")
|
||||||
|
if data_args.language != "all":
|
||||||
|
raise ValueError("--language_group should only be used with --language=all")
|
||||||
|
|
||||||
if data_args.target_column_name is None:
|
if data_args.target_column_name is None:
|
||||||
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
|
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
|
||||||
@@ -502,11 +520,23 @@ def main():
|
|||||||
if data_args.max_predict_samples is not None:
|
if data_args.max_predict_samples is not None:
|
||||||
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
|
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
|
||||||
|
|
||||||
|
lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
|
||||||
if not is_text_target:
|
if not is_text_target:
|
||||||
label_list = next(iter(raw_datasets.values())).features[target_column_name].names
|
label_list = next(iter(raw_datasets.values())).features[target_column_name].names
|
||||||
lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
|
|
||||||
num_labels = len(label_list)
|
num_labels = len(label_list)
|
||||||
|
|
||||||
|
num_workers = data_args.preprocessing_num_workers
|
||||||
|
|
||||||
|
lang_group = data_args.language_group
|
||||||
|
if lang_group is not None:
|
||||||
|
with training_args.main_process_first(desc="language group filter"):
|
||||||
|
lang_group_id = next(iter(raw_datasets.values())).features["lang_group_id"].str2int(lang_group)
|
||||||
|
raw_datasets = raw_datasets.filter(
|
||||||
|
lambda lang_group: lang_group == lang_group_id,
|
||||||
|
num_proc=num_workers,
|
||||||
|
input_columns=["lang_group_id"],
|
||||||
|
)
|
||||||
|
|
||||||
# 2. We remove some special characters from the datasets
|
# 2. We remove some special characters from the datasets
|
||||||
# that make training complicated and do not help in transcribing the speech
|
# that make training complicated and do not help in transcribing the speech
|
||||||
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
|
||||||
@@ -616,6 +646,7 @@ def main():
|
|||||||
"mask_feature_length": model_args.mask_feature_length,
|
"mask_feature_length": model_args.mask_feature_length,
|
||||||
"gradient_checkpointing": training_args.gradient_checkpointing,
|
"gradient_checkpointing": training_args.gradient_checkpointing,
|
||||||
"layerdrop": model_args.layerdrop,
|
"layerdrop": model_args.layerdrop,
|
||||||
|
"ctc_zero_infinity": model_args.ctc_zero_infinity,
|
||||||
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
"ctc_loss_reduction": model_args.ctc_loss_reduction,
|
||||||
"activation_dropout": model_args.activation_dropout,
|
"activation_dropout": model_args.activation_dropout,
|
||||||
}
|
}
|
||||||
@@ -675,7 +706,6 @@ def main():
|
|||||||
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
|
||||||
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
|
||||||
audio_column_name = data_args.audio_column_name
|
audio_column_name = data_args.audio_column_name
|
||||||
num_workers = data_args.preprocessing_num_workers
|
|
||||||
|
|
||||||
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
|
||||||
phoneme_language = data_args.phoneme_language
|
phoneme_language = data_args.phoneme_language
|
||||||
@@ -740,13 +770,13 @@ def main():
|
|||||||
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
|
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
|
||||||
return
|
return
|
||||||
|
|
||||||
def compute_asr_metric(pred):
|
def asr_logits_argmax(logits, labels):
|
||||||
pred_logits = pred.predictions
|
return logits.argmax(dim=-1)
|
||||||
pred_ids = np.argmax(pred_logits, axis=-1)
|
|
||||||
|
|
||||||
|
def compute_asr_metric(pred):
|
||||||
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
|
||||||
|
|
||||||
pred_str = tokenizer.batch_decode(pred_ids)
|
pred_str = tokenizer.batch_decode(pred.predictions)
|
||||||
# we do not want to group tokens when computing the metrics
|
# we do not want to group tokens when computing the metrics
|
||||||
label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
|
label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
|
||||||
|
|
||||||
@@ -783,6 +813,7 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
preprocess_logits_for_metrics=asr_logits_argmax if training_args.predict_with_generate else None,
|
||||||
compute_metrics=compute_asr_metric if training_args.predict_with_generate else None,
|
compute_metrics=compute_asr_metric if training_args.predict_with_generate else None,
|
||||||
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
||||||
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
||||||
@@ -793,6 +824,7 @@ def main():
|
|||||||
model=model,
|
model=model,
|
||||||
data_collator=data_collator,
|
data_collator=data_collator,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
preprocess_logits_for_metrics=asr_logits_argmax if is_text_target else None,
|
||||||
compute_metrics=compute_asr_metric if is_text_target else compute_classification_metric,
|
compute_metrics=compute_asr_metric if is_text_target else compute_classification_metric,
|
||||||
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
|
||||||
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
|
||||||
@@ -837,11 +869,17 @@ def main():
|
|||||||
average_metrics = defaultdict(list)
|
average_metrics = defaultdict(list)
|
||||||
for lang_id in range(len(lang_list)):
|
for lang_id in range(len(lang_list)):
|
||||||
lang_name = lang_list[lang_id]
|
lang_name = lang_list[lang_id]
|
||||||
lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id)
|
with training_args.main_process_first(desc="per-language dataset filter"):
|
||||||
|
lang_dataset = vectorized_datasets["predict"].filter(
|
||||||
|
lambda lang: lang == lang_id,
|
||||||
|
num_proc=num_workers,
|
||||||
|
input_columns=["lang"],
|
||||||
|
)
|
||||||
lang_metrics = trainer.evaluate(lang_dataset)
|
lang_metrics = trainer.evaluate(lang_dataset)
|
||||||
|
redundant_metrics = ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second", "eval_epoch"]
|
||||||
for metric_name, value in lang_metrics.items():
|
for metric_name, value in lang_metrics.items():
|
||||||
average_metrics[metric_name].append(value)
|
average_metrics[metric_name].append(value)
|
||||||
if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]:
|
if metric_name not in redundant_metrics:
|
||||||
metrics[f"{metric_name}_{lang_name}"] = value
|
metrics[f"{metric_name}_{lang_name}"] = value
|
||||||
for metric_name, value in average_metrics.items():
|
for metric_name, value in average_metrics.items():
|
||||||
metrics[metric_name] = np.mean(value)
|
metrics[metric_name] = np.mean(value)
|
||||||
|
|||||||
Reference in New Issue
Block a user