[Trainer] add train loss and flops metrics reports (#11980)

* add train loss and flops metrics reports

* consistency

* add train_loss to skip keys

* restore on_train_end call timing
This commit is contained in:
Stas Bekman
2021-06-01 15:58:31 -07:00
committed by GitHub
parent 7ec596ecda
commit 4ba203d9d3
2 changed files with 14 additions and 12 deletions

View File

@@ -311,13 +311,11 @@ class TrainerIntegrationCommon:
log_history = state.pop("log_history", None)
log_history1 = state1.pop("log_history", None)
self.assertEqual(state, state1)
skip_log_keys = ["train_runtime", "train_samples_per_second", "train_steps_per_second", "train_loss"]
for log, log1 in zip(log_history, log_history1):
_ = log.pop("train_runtime", None)
_ = log1.pop("train_runtime", None)
_ = log.pop("train_samples_per_second", None)
_ = log1.pop("train_samples_per_second", None)
_ = log.pop("train_steps_per_second", None)
_ = log1.pop("train_steps_per_second", None)
for key in skip_log_keys:
_ = log.pop(key, None)
_ = log1.pop(key, None)
self.assertEqual(log, log1)