modify qa-trainer (#11872)

* modify qa-trainer

* fix flax model
This commit is contained in:
Fan Zhang
2021-06-01 20:28:41 +08:00
committed by GitHub
parent 9ec0f01b6c
commit 7e73601f32
25 changed files with 57 additions and 49 deletions

View File

@@ -692,7 +692,11 @@ def main():
if completed_steps >= args.max_train_steps:
break
# Validation
# Evaluation
logger.info("***** Running Evaluation *****")
logger.info(f" Num examples = {len(eval_dataset)}")
logger.info(f" Batch size = {args.per_device_eval_batch_size}")
all_start_logits = []
all_end_logits = []
for step, batch in enumerate(eval_dataloader):
@@ -725,6 +729,10 @@ def main():
# Prediction
if args.do_predict:
logger.info("***** Running Prediction *****")
logger.info(f" Num examples = {len(predict_dataset)}")
logger.info(f" Batch size = {args.per_device_eval_batch_size}")
all_start_logits = []
all_end_logits = []
for step, batch in enumerate(predict_dataloader):