Store FLOS as floats to avoid overflow. (#10213)

This commit is contained in:
Sylvain Gugger
2021-02-16 11:15:15 -05:00
committed by GitHub
parent df1b0fb54d
commit 7169d1ea7b
3 changed files with 8 additions and 4 deletions

View File

@@ -959,7 +959,7 @@ class Trainer:
tr_loss += self.training_step(model, inputs) tr_loss += self.training_step(model, inputs)
else: else:
tr_loss += self.training_step(model, inputs) tr_loss += self.training_step(model, inputs)
self._total_flos += self.floating_point_ops(inputs) self._total_flos += float(self.floating_point_ops(inputs))
if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps # last step in epoch but step is always smaller than gradient_accumulation_steps

View File

@@ -52,8 +52,9 @@ class TrainerState:
During training, represents the number of update steps completed. During training, represents the number of update steps completed.
max_steps (:obj:`int`, `optional`, defaults to 0): max_steps (:obj:`int`, `optional`, defaults to 0):
The number of update steps to do during the current training. The number of update steps to do during the current training.
total_flos (:obj:`int`, `optional`, defaults to 0): total_flos (:obj:`float`, `optional`, defaults to 0):
The total number of floating operations done by the model since the beginning of training. The total number of floating operations done by the model since the beginning of training (stored as floats
to avoid overflow).
log_history (:obj:`List[Dict[str, float]]`, `optional`): log_history (:obj:`List[Dict[str, float]]`, `optional`):
The list of logs done since the beginning of training. The list of logs done since the beginning of training.
best_metric (:obj:`float`, `optional`): best_metric (:obj:`float`, `optional`):
@@ -76,7 +77,7 @@ class TrainerState:
global_step: int = 0 global_step: int = 0
max_steps: int = 0 max_steps: int = 0
num_train_epochs: int = 0 num_train_epochs: int = 0
total_flos: int = 0 total_flos: float = 0
log_history: List[Dict[str, float]] = None log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None best_metric: Optional[float] = None
best_model_checkpoint: Optional[str] = None best_model_checkpoint: Optional[str] = None

View File

@@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase):
# with enforced DataParallel # with enforced DataParallel
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model)) assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
trainer.train()
self.assertTrue(isinstance(trainer.state.total_flos, float))
@require_torch @require_torch
@require_optuna @require_optuna