Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
This commit is contained in:
Sylvain Gugger
2020-12-18 15:10:39 -05:00
committed by GitHub
parent 9a25c5bd3a
commit 1198ba8fba
6 changed files with 76 additions and 49 deletions

View File

@@ -265,6 +265,21 @@ class TrainerIntegrationTest(unittest.TestCase):
metrics = trainer.evaluate()
self.assertEqual(metrics[metric], best_value)
def check_trainer_state_are_the_same(self, trainer_state, trainer_state1):
    """Assert that two trainer-state dicts are equal, ignoring timing metrics.

    Every key except ``"log_history"`` must compare equal. The log histories
    must have the same number of entries, and each pair of entries must be
    equal once the ``"train_runtime"`` and ``"train_samples_per_second"`` keys
    are left out.

    Neither input dict (nor the log dicts it holds) is modified.

    Args:
        trainer_state: First state, as a plain ``dict``.
        trainer_state1: Second state, as a plain ``dict``.

    Raises:
        AssertionError: If the states differ in anything but the timing keys.
    """
    # Log history may contain different values for the time metrics (after resuming a training).
    time_keys = ("train_runtime", "train_samples_per_second")

    # Shallow copies are enough here: we only remove the top-level "log_history" key from them.
    state = trainer_state.copy()
    state1 = trainer_state1.copy()
    # A missing (or None) history is treated as empty instead of crashing in zip().
    log_history = state.pop("log_history", None) or []
    log_history1 = state1.pop("log_history", None) or []
    self.assertEqual(state, state1)

    # zip() stops at the shorter sequence, so a truncated history would otherwise go unnoticed.
    self.assertEqual(len(log_history), len(log_history1))
    for log, log1 in zip(log_history, log_history1):
        # The log dicts are shared with the caller's state (the copies above are shallow),
        # so compare filtered copies rather than popping keys in place.
        comparable = {key: value for key, value in log.items() if key not in time_keys}
        comparable1 = {key: value for key, value in log1.items() if key not in time_keys}
        self.assertEqual(comparable, comparable1)
def test_trainer_works_with_dict(self):
# Edge case because Apex with mode O2 will change our models to return dicts. This test checks it doesn't break
# anything.
@@ -552,7 +567,7 @@ class TrainerIntegrationTest(unittest.TestCase):
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
self.check_trainer_state_are_the_same(state, state1)
# Now check with a later checkpoint that it also works when we span over one epoch
checkpoint = os.path.join(tmpdir, "checkpoint-15")
@@ -566,7 +581,7 @@ class TrainerIntegrationTest(unittest.TestCase):
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
self.check_trainer_state_are_the_same(state, state1)
# With a regular model that is not a PreTrainedModel
with tempfile.TemporaryDirectory() as tmpdir:
@@ -590,7 +605,7 @@ class TrainerIntegrationTest(unittest.TestCase):
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
self.check_trainer_state_are_the_same(state, state1)
# Now check with a later checkpoint that it also works when we span over one epoch
checkpoint = os.path.join(tmpdir, "checkpoint-15")
@@ -606,7 +621,7 @@ class TrainerIntegrationTest(unittest.TestCase):
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
self.check_trainer_state_are_the_same(state, state1)
def test_resume_training_with_gradient_accumulation(self):
if torch.cuda.device_count() > 2:
@@ -638,7 +653,7 @@ class TrainerIntegrationTest(unittest.TestCase):
state1 = dataclasses.asdict(trainer.state)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
self.assertEqual(state, state1)
self.check_trainer_state_are_the_same(state, state1)
def test_load_best_model_at_end(self):
total = int(self.n_epochs * 64 / self.batch_size)