Store FLOS as floats to avoid overflow. (#10213)

This commit is contained in:
Sylvain Gugger
2021-02-16 11:15:15 -05:00
committed by GitHub
parent df1b0fb54d
commit 7169d1ea7b
3 changed files with 8 additions and 4 deletions

View File

@@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase):
# with enforced DataParallel
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
trainer.train()
self.assertTrue(isinstance(trainer.state.total_flos, float))
@require_torch
@require_optuna