Store FLOS as floats to avoid overflow. (#10213)
This commit is contained in:
@@ -959,7 +959,7 @@ class Trainer:
|
|||||||
tr_loss += self.training_step(model, inputs)
|
tr_loss += self.training_step(model, inputs)
|
||||||
else:
|
else:
|
||||||
tr_loss += self.training_step(model, inputs)
|
tr_loss += self.training_step(model, inputs)
|
||||||
self._total_flos += self.floating_point_ops(inputs)
|
self._total_flos += float(self.floating_point_ops(inputs))
|
||||||
|
|
||||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||||
|
|||||||
@@ -52,8 +52,9 @@ class TrainerState:
|
|||||||
During training, represents the number of update steps completed.
|
During training, represents the number of update steps completed.
|
||||||
max_steps (:obj:`int`, `optional`, defaults to 0):
|
max_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
The number of update steps to do during the current training.
|
The number of update steps to do during the current training.
|
||||||
total_flos (:obj:`int`, `optional`, defaults to 0):
|
total_flos (:obj:`float`, `optional`, defaults to 0):
|
||||||
The total number of floating operations done by the model since the beginning of training.
|
The total number of floating operations done by the model since the beginning of training (stored as floats
|
||||||
|
to avoid overflow).
|
||||||
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
log_history (:obj:`List[Dict[str, float]]`, `optional`):
|
||||||
The list of logs done since the beginning of training.
|
The list of logs done since the beginning of training.
|
||||||
best_metric (:obj:`float`, `optional`):
|
best_metric (:obj:`float`, `optional`):
|
||||||
@@ -76,7 +77,7 @@ class TrainerState:
|
|||||||
global_step: int = 0
|
global_step: int = 0
|
||||||
max_steps: int = 0
|
max_steps: int = 0
|
||||||
num_train_epochs: int = 0
|
num_train_epochs: int = 0
|
||||||
total_flos: int = 0
|
total_flos: float = 0
|
||||||
log_history: List[Dict[str, float]] = None
|
log_history: List[Dict[str, float]] = None
|
||||||
best_metric: Optional[float] = None
|
best_metric: Optional[float] = None
|
||||||
best_model_checkpoint: Optional[str] = None
|
best_model_checkpoint: Optional[str] = None
|
||||||
|
|||||||
@@ -881,6 +881,9 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
# with enforced DataParallel
|
# with enforced DataParallel
|
||||||
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
assert_flos_extraction(trainer, torch.nn.DataParallel(trainer.model))
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
self.assertTrue(isinstance(trainer.state.total_flos, float))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_optuna
|
@require_optuna
|
||||||
|
|||||||
Reference in New Issue
Block a user