Store FLOS as floats to avoid overflow. (#10213)
This commit is contained in:
@@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# with enforced DataParallel
|
||||
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
||||
|
||||
trainer.train()
|
||||
self.assertTrue(isinstance(trainer.state.total_flos, float))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_optuna
|
||||
|
||||
Reference in New Issue
Block a user