fix steps_in_epoch variable in trainer when using max_steps (#9969)

* fix steps_in_epoch variable when using max_steps

* redundant sentence

* Revert "redundant sentence"

This reverts commit ad5c0e9b6e66d65732dee2239cdc9c76dfa0dc5a.

* remove redundant sentence

Co-authored-by: wujindou <wujindou@sogou-inc.com>
This commit is contained in:
yylun
2021-02-03 22:30:37 +08:00
committed by GitHub
parent 3f77c26d74
commit 5442a11f5f
3 changed files with 6 additions and 5 deletions

View File

@@ -910,7 +910,11 @@ class Trainer:
if self.args.past_index >= 0:
self._past = None
steps_in_epoch = len(epoch_iterator) if train_dataset_is_sized else self.args.max_steps
steps_in_epoch = (
len(epoch_iterator)
if train_dataset_is_sized
else self.args.max_steps * self.args.gradient_accumulation_steps
)
self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
for step, inputs in enumerate(epoch_iterator):