Distributed eval: SequentialDistributedSampler + gather all results (#4243)

* Distributed eval: SequentialDistributedSampler + gather all results

* For consistency only write to disk from world_master

Close https://github.com/huggingface/transformers/issues/4272

* Working distributed eval

* Hook into scripts

* Fix #3721 again

* TPU.mesh_reduce: stay in tensor space

Thanks @jysohn23

* Just a small comment

* whitespace

* torch.hub: pip install packaging

* Add test scenarii
This commit is contained in:
Julien Chaumond
2020-05-18 22:02:39 -04:00
committed by GitHub
parent 4c06893610
commit 5e7fe8b585
7 changed files with 280 additions and 83 deletions

View File

@@ -235,22 +235,23 @@ def main():
# Evaluation
results = {}
if training_args.do_eval and training_args.local_rank in [-1, 0]:
if training_args.do_eval:
logger.info("*** Evaluate ***")
result = trainer.evaluate()
output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
if trainer.is_world_master():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in result.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
results.update(result)
# Predict
if training_args.do_predict and training_args.local_rank in [-1, 0]:
if training_args.do_predict:
test_dataset = NerDataset(
data_dir=data_args.data_dir,
tokenizer=tokenizer,
@@ -265,26 +266,30 @@ def main():
preds_list, _ = align_predictions(predictions, label_ids)
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
with open(output_test_results_file, "w") as writer:
for key, value in metrics.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
if trainer.is_world_master():
with open(output_test_results_file, "w") as writer:
for key, value in metrics.items():
logger.info(" %s = %s", key, value)
writer.write("%s = %s\n" % (key, value))
# Save predictions
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not preds_list[example_id]:
example_id += 1
elif preds_list[example_id]:
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
if trainer.is_world_master():
with open(output_test_predictions_file, "w") as writer:
with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not preds_list[example_id]:
example_id += 1
elif preds_list[example_id]:
output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logger.warning(
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
)
return results