Add possibility to evaluate every epoch (#7302)

* Add possibility to evaluate every epoch

* Remove multitype arg

* Remove needless import

* Use a proper enum

* Apply suggestions from @LysandreJik

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* One else and formatting

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-09-22 09:52:29 -04:00
committed by GitHub
parent 21ca148090
commit 89edf504bf
3 changed files with 58 additions and 9 deletions

View File

@@ -39,6 +39,7 @@ from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalPrediction,
EvaluationStrategy,
HPSearchBackend,
PredictionOutput,
TrainOutput,
@@ -782,7 +783,10 @@ class Trainer:
self.log(logs)
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
if (
self.args.evaluation_strategy == EvaluationStrategy.STEPS
and self.global_step % self.args.eval_steps == 0
):
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
@@ -820,6 +824,9 @@ class Trainer:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
epoch_pbar.update(1)
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
metrics = self.evaluate()
self._report_to_hp_search(trial, epoch, metrics)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
break
epoch_pbar.close()