Reload checkpoint (#7984)
* Fix checkpoint loading in Trainer * Fix typo
This commit is contained in:
@@ -436,10 +436,12 @@ class ProgressCallback(TrainerCallback):
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar = tqdm(total=state.max_steps)
|
||||
self.current_step = 0
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
self.training_bar.update(1)
|
||||
self.training_bar.update(state.global_step - self.current_step)
|
||||
self.current_step = state.global_step
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_local_process_zero:
|
||||
|
||||
Reference in New Issue
Block a user