Only log total_flos at the end of training (#7981)
* Only log total_flos at the end of training
* Fix test
This commit is contained in:
@@ -830,6 +830,10 @@ class Trainer:
|
|||||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||||
self.model.load_state_dict(state_dict)
|
self.model.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
if self._total_flos is not None:
|
||||||
|
self.store_flos()
|
||||||
|
self.log({"total_flos": self.state.total_flos})
|
||||||
|
|
||||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||||
|
|
||||||
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
|
||||||
@@ -1013,9 +1017,6 @@ class Trainer:
|
|||||||
return self._log(logs)
|
return self._log(logs)
|
||||||
if self.state.epoch is not None:
|
if self.state.epoch is not None:
|
||||||
logs["epoch"] = self.state.epoch
|
logs["epoch"] = self.state.epoch
|
||||||
if self._total_flos is not None:
|
|
||||||
self.store_flos()
|
|
||||||
logs["total_flos"] = self.state.total_flos
|
|
||||||
|
|
||||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||||
output = {**logs, **{"step": self.state.global_step}}
|
output = {**logs, **{"step": self.state.global_step}}
|
||||||
|
|||||||
@@ -114,12 +114,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
|||||||
metrics = copy.deepcopy(metrics)
|
metrics = copy.deepcopy(metrics)
|
||||||
loss = metrics.pop("eval_loss", None)
|
loss = metrics.pop("eval_loss", None)
|
||||||
_ = metrics.pop("epoch", None)
|
_ = metrics.pop("epoch", None)
|
||||||
_ = metrics.pop("total_flos", None)
|
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||||
if len(metrics) != 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Metrics contains more entries than just 'eval_loss', 'epoch' and 'total_flos', please provide your own compute_objective function."
|
|
||||||
)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
def default_hp_space_optuna(trial) -> Dict[str, float]:
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class TrainerCallbackTest(unittest.TestCase):
|
|||||||
expected_events.append("on_epoch_end")
|
expected_events.append("on_epoch_end")
|
||||||
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
||||||
expected_events += evaluation_events.copy()
|
expected_events += evaluation_events.copy()
|
||||||
expected_events.append("on_train_end")
|
expected_events += ["on_log", "on_train_end"]
|
||||||
return expected_events
|
return expected_events
|
||||||
|
|
||||||
def test_init_callback(self):
|
def test_init_callback(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user