[Examples] Fixes inconsistency around eval vs val and predict vs test (#11380)

* added changes for uniformity

* modified files

* corrected typo

* fixed qa scripts

* fix typos

* fixed predict typo in qa no trainer

* fixed test file

* reverted trainer changes

* reverted trainer changes in custom exmaples

* updated readme

* added changes in deepspeed test

* added changes for predict and eval
This commit is contained in:
Bhadresh Savani
2021-04-26 17:24:31 +01:00
committed by GitHub
parent 7959d83599
commit 1d30ec95c7
19 changed files with 288 additions and 276 deletions

View File

@@ -100,17 +100,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
@@ -390,15 +390,15 @@ def main():
if "validation" not in datasets and "validation_matched" not in datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
if "test" not in datasets and "test_matched" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
# Log a few random samples from the training set:
if training_args.do_train:
@@ -483,32 +483,34 @@ def main():
for eval_dataset, task in zip(eval_datasets, tasks):
metrics = trainer.evaluate(eval_dataset=eval_dataset)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
if training_args.do_predict:
logger.info("*** Test ***")
logger.info("*** Predict ***")
# Loop to handle MNLI double evaluation (matched, mis-matched)
tasks = [data_args.task_name]
test_datasets = [test_dataset]
predict_datasets = [predict_dataset]
if data_args.task_name == "mnli":
tasks.append("mnli-mm")
test_datasets.append(datasets["test_mismatched"])
predict_datasets.append(datasets["test_mismatched"])
for test_dataset, task in zip(test_datasets, tasks):
for predict_dataset, task in zip(predict_datasets, tasks):
# Removing the `label` columns because it contains -1 and Trainer won't like that.
test_dataset.remove_columns_("label")
predictions = trainer.predict(test_dataset=test_dataset).predictions
predict_dataset.remove_columns_("label")
predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
logger.info(f"***** Test results {task} *****")
with open(output_predict_file, "w") as writer:
logger.info(f"***** Predict results {task} *****")
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
if is_regression:

View File

@@ -84,17 +84,17 @@ class DataTrainingArguments:
"value if set."
},
)
max_val_samples: Optional[int] = field(
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
@@ -221,8 +221,8 @@ def main():
label_list = eval_dataset.features["label"].names
if training_args.do_predict:
test_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
label_list = test_dataset.features["label"].names
predict_dataset = load_dataset("xnli", model_args.language, split="test", cache_dir=model_args.cache_dir)
label_list = predict_dataset.features["label"].names
# Labels
num_labels = len(label_list)
@@ -286,8 +286,8 @@ def main():
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
if training_args.do_eval:
if data_args.max_val_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
if data_args.max_eval_samples is not None:
eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
@@ -295,9 +295,9 @@ def main():
)
if training_args.do_predict:
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
if data_args.max_predict_samples is not None:
predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
predict_dataset = predict_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
@@ -360,8 +360,8 @@ def main():
logger.info("*** Evaluate ***")
metrics = trainer.evaluate(eval_dataset=eval_dataset)
max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
@@ -369,18 +369,20 @@ def main():
# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)
predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))
max_predict_samples = (
data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
)
metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)
trainer.log_metrics("predict", metrics)
trainer.save_metrics("predict", metrics)
predictions = np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
output_predict_file = os.path.join(training_args.output_dir, "predictions.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
with open(output_predict_file, "w") as writer:
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
item = label_list[item]