[Trainer] add train loss and flops metrics reports (#11980)
* add train loss and flops metrics reports * consistency * add train_loss to skip keys * restore on_train_end call timing
This commit is contained in:
@@ -311,13 +311,11 @@ class TrainerIntegrationCommon:
|
||||
log_history = state.pop("log_history", None)
|
||||
log_history1 = state1.pop("log_history", None)
|
||||
self.assertEqual(state, state1)
|
||||
skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]
|
||||
for log, log1 in zip(log_history, log_history1):
|
||||
_ = log.pop("train_runtime", None)
|
||||
_ = log1.pop("train_runtime", None)
|
||||
_ = log.pop("train_samples_per_second", None)
|
||||
_ = log1.pop("train_samples_per_second", None)
|
||||
_ = log.pop("train_steps_per_second", None)
|
||||
_ = log1.pop("train_steps_per_second", None)
|
||||
for key in skip_log_keys:
|
||||
_ = log.pop(key, None)
|
||||
_ = log1.pop(key, None)
|
||||
self.assertEqual(log, log1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user