Add possibility to evaluate every epoch (#7302)
* Add possibility to evaluate every epoch * Remove multitype arg * Remove needless import * Use a proper enum * Apply suggestions from @LysandreJik Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * One else and formatting Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -39,6 +39,7 @@ from .trainer_utils import (
|
||||
PREFIX_CHECKPOINT_DIR,
|
||||
BestRun,
|
||||
EvalPrediction,
|
||||
EvaluationStrategy,
|
||||
HPSearchBackend,
|
||||
PredictionOutput,
|
||||
TrainOutput,
|
||||
@@ -782,7 +783,10 @@ class Trainer:
|
||||
|
||||
self.log(logs)
|
||||
|
||||
if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
|
||||
if (
|
||||
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
||||
and self.global_step % self.args.eval_steps == 0
|
||||
):
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
|
||||
@@ -820,6 +824,9 @@ class Trainer:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
epoch_pbar.update(1)
|
||||
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||
metrics = self.evaluate()
|
||||
self._report_to_hp_search(trial, epoch, metrics)
|
||||
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
||||
break
|
||||
epoch_pbar.close()
|
||||
|
||||
Reference in New Issue
Block a user