Reload checkpoint (#7984)

* Fix checkpoint loading in Trainer

* Fix typo
This commit is contained in:
Sylvain Gugger
2020-10-22 15:48:52 -04:00
committed by GitHub
parent 467573ddde
commit 5ae935d233
3 changed files with 35 additions and 17 deletions

View File

@@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback):
def on_train_begin(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar = tqdm(total=state.max_steps)
self.current_step = 0
def on_step_end(self, args, state, control, **kwargs):
if state.is_local_process_zero:
self.training_bar.update(1)
self.training_bar.update(state.global_step - self.current_step)
self.current_step = state.global_step
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
if state.is_local_process_zero: