Fix GLUE MNLI when using max_eval_samples (#18722)

This commit is contained in:
Leandro von Werra
2022-09-21 11:33:22 +04:00
committed by GitHub
parent 18643ff29a
commit ef6741fe65

View File

@@ -549,7 +549,11 @@ def main():
eval_datasets = [eval_dataset] eval_datasets = [eval_dataset]
if data_args.task_name == "mnli": if data_args.task_name == "mnli":
tasks.append("mnli-mm") tasks.append("mnli-mm")
eval_datasets.append(raw_datasets["validation_mismatched"]) valid_mm_dataset = raw_datasets["validation_mismatched"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples)
valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples))
eval_datasets.append(valid_mm_dataset)
combined = {} combined = {}
for eval_dataset, task in zip(eval_datasets, tasks): for eval_dataset, task in zip(eval_datasets, tasks):