@@ -692,7 +692,11 @@ def main():
|
||||
if completed_steps >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# Validation
|
||||
# Evaluation
|
||||
logger.info("***** Running Evaluation *****")
|
||||
logger.info(f" Num examples = {len(eval_dataset)}")
|
||||
logger.info(f" Batch size = {args.per_device_eval_batch_size}")
|
||||
|
||||
all_start_logits = []
|
||||
all_end_logits = []
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
@@ -725,6 +729,10 @@ def main():
|
||||
|
||||
# Prediction
|
||||
if args.do_predict:
|
||||
logger.info("***** Running Prediction *****")
|
||||
logger.info(f" Num examples = {len(predict_dataset)}")
|
||||
logger.info(f" Batch size = {args.per_device_eval_batch_size}")
|
||||
|
||||
all_start_logits = []
|
||||
all_end_logits = []
|
||||
for step, batch in enumerate(predict_dataloader):
|
||||
|
||||
Reference in New Issue
Block a user