Adds predict stage for GLUE tasks, and generates result files which can be submitted to gluebenchmark.com (#4463)
* Adds predict stage for GLUE tasks, and generates result files which can be submitted to the gluebenchmark.com website.
* Use Split enum + always output the label name

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -419,7 +419,7 @@ def main():
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Prepare dataset for the GLUE task
|
||||
eval_dataset = GlueDataset(args, tokenizer=tokenizer, evaluate=True)
|
||||
eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
|
||||
if args.data_subset > 0:
|
||||
eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
|
||||
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
|
||||
|
||||
@@ -135,7 +135,8 @@ def main():
|
||||
|
||||
# Get datasets
|
||||
train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
|
||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
|
||||
eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None
|
||||
test_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="test") if training_args.do_predict else None
|
||||
|
||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||
if output_mode == "classification":
|
||||
@@ -165,7 +166,7 @@ def main():
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
eval_results = {}
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
|
||||
@@ -173,10 +174,10 @@ def main():
|
||||
eval_datasets = [eval_dataset]
|
||||
if data_args.task_name == "mnli":
|
||||
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
||||
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, evaluate=True))
|
||||
eval_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev"))
|
||||
|
||||
for eval_dataset in eval_datasets:
|
||||
result = trainer.evaluate(eval_dataset=eval_dataset)
|
||||
eval_result = trainer.evaluate(eval_dataset=eval_dataset)
|
||||
|
||||
output_eval_file = os.path.join(
|
||||
training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt"
|
||||
@@ -184,13 +185,38 @@ def main():
|
||||
if trainer.is_world_master():
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name))
|
||||
for key, value in result.items():
|
||||
for key, value in eval_result.items():
|
||||
logger.info(" %s = %s", key, value)
|
||||
writer.write("%s = %s\n" % (key, value))
|
||||
|
||||
results.update(result)
|
||||
eval_results.update(eval_result)
|
||||
|
||||
return results
|
||||
if training_args.do_predict:
|
||||
logging.info("*** Test ***")
|
||||
test_datasets = [test_dataset]
|
||||
if data_args.task_name == "mnli":
|
||||
mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
|
||||
test_datasets.append(GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test"))
|
||||
|
||||
for test_dataset in test_datasets:
|
||||
predictions = trainer.predict(test_dataset=test_dataset).predictions
|
||||
if output_mode == "classification":
|
||||
predictions = np.argmax(predictions, axis=1)
|
||||
|
||||
output_test_file = os.path.join(
|
||||
training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt"
|
||||
)
|
||||
if trainer.is_world_master():
|
||||
with open(output_test_file, "w") as writer:
|
||||
logger.info("***** Test results {} *****".format(test_dataset.args.task_name))
|
||||
writer.write("index\tprediction\n")
|
||||
for index, item in enumerate(predictions):
|
||||
if output_mode == "regression":
|
||||
writer.write("%d\t%3.3f\n" % (index, item))
|
||||
else:
|
||||
item = test_dataset.get_labels()[item]
|
||||
writer.write("%d\t%s\n" % (index, item))
|
||||
return eval_results
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
|
||||
Reference in New Issue
Block a user