From 7169d1ea7b571bf303caf5c3c35e584f47cb398c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 16 Feb 2021 11:15:15 -0500 Subject: [PATCH] Store FLOS as floats to avoid overflow. (#10213) --- src/transformers/trainer.py | 2 +- src/transformers/trainer_callback.py | 7 ++++--- tests/test_trainer.py | 3 +++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index b0f7e694ea..44e210c153 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -959,7 +959,7 @@ class Trainer: tr_loss += self.training_step(model, inputs) else: tr_loss += self.training_step(model, inputs) - self._total_flos += self.floating_point_ops(inputs) + self._total_flos += float(self.floating_point_ops(inputs)) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index ea2ea3cd82..34027dc9e1 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -52,8 +52,9 @@ class TrainerState: During training, represents the number of update steps completed. max_steps (:obj:`int`, `optional`, defaults to 0): The number of update steps to do during the current training. - total_flos (:obj:`int`, `optional`, defaults to 0): - The total number of floating operations done by the model since the beginning of training. + total_flos (:obj:`float`, `optional`, defaults to 0): + The total number of floating operations done by the model since the beginning of training (stored as floats + to avoid overflow). log_history (:obj:`List[Dict[str, float]]`, `optional`): The list of logs done since the beginning of training. best_metric (:obj:`float`, `optional`): @@ -76,7 +77,7 @@ class TrainerState: global_step: int = 0 max_steps: int = 0 num_train_epochs: int = 0 - total_flos: int = 0 + total_flos: float = 0 log_history: List[Dict[str, float]] = None best_metric: Optional[float] = None best_model_checkpoint: Optional[str] = None diff --git a/tests/test_trainer.py b/tests/test_trainer.py index d382dbc40b..ed8c337f7c 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase): # with enforced DataParallel assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model)) + trainer.train() + self.assertTrue(isinstance(trainer.state.total_flos, float)) + @require_torch @require_optuna