Small fixes to NotebookProgressCallback (#7813)
This commit is contained in:
@@ -153,7 +153,7 @@ try:
|
|||||||
import IPython # noqa: F401
|
import IPython # noqa: F401
|
||||||
|
|
||||||
_in_notebook = True
|
_in_notebook = True
|
||||||
except: # noqa: E722
|
except (AttributeError, ImportError, KeyError):
|
||||||
_in_notebook = False
|
_in_notebook = False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import Optional
|
|||||||
import IPython.display as disp
|
import IPython.display as disp
|
||||||
|
|
||||||
from ..trainer_callback import TrainerCallback
|
from ..trainer_callback import TrainerCallback
|
||||||
|
from ..trainer_utils import EvaluationStrategy
|
||||||
|
|
||||||
|
|
||||||
def format_time(t):
|
def format_time(t):
|
||||||
@@ -146,7 +147,7 @@ class NotebookProgressBar:
|
|||||||
self.first_calls = self.warmup
|
self.first_calls = self.warmup
|
||||||
self.wait_for = 1
|
self.wait_for = 1
|
||||||
self.update_bar(value)
|
self.update_bar(value)
|
||||||
elif value <= self.last_value:
|
elif value <= self.last_value and not force_update:
|
||||||
return
|
return
|
||||||
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
|
elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total):
|
||||||
if self.first_calls > 0:
|
if self.first_calls > 0:
|
||||||
@@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.training_tracker = None
|
self.training_tracker = None
|
||||||
self.prediction_bar = None
|
self.prediction_bar = None
|
||||||
|
self._force_next_update = False
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, **kwargs):
|
def on_train_begin(self, args, state, control, **kwargs):
|
||||||
self.first_column = "Epoch" if args.max_steps <= 0 else "Step"
|
self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step"
|
||||||
self.training_loss = 0
|
self.training_loss = 0
|
||||||
self.last_log = 0
|
self.last_log = 0
|
||||||
column_names = [self.first_column] + ["Training Loss", "Validation Loss"]
|
column_names = [self.first_column] + ["Training Loss"]
|
||||||
|
if args.evaluation_strategy != EvaluationStrategy.NO:
|
||||||
|
column_names.append("Validation Loss")
|
||||||
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
|
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
|
||||||
|
|
||||||
def on_step_end(self, args, state, control, **kwargs):
|
def on_step_end(self, args, state, control, **kwargs):
|
||||||
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
|
epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}"
|
||||||
self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}")
|
self.training_tracker.update(
|
||||||
|
state.global_step + 1,
|
||||||
|
comment=f"Epoch {epoch}/{state.num_train_epochs}",
|
||||||
|
force_update=self._force_next_update,
|
||||||
|
)
|
||||||
|
self._force_next_update = False
|
||||||
|
|
||||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||||
if self.prediction_bar is None:
|
if self.prediction_bar is None:
|
||||||
@@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
else:
|
else:
|
||||||
self.prediction_bar.update(self.prediction_bar.value + 1)
|
self.prediction_bar.update(self.prediction_bar.value + 1)
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||||
|
# Only for when there is no evaluation
|
||||||
|
if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs:
|
||||||
|
values = {"Training Loss": logs["loss"]}
|
||||||
|
# First column is necessarily Step since we're not in epoch eval strategy
|
||||||
|
values["Step"] = state.global_step
|
||||||
|
self.training_tracker.write_line(values)
|
||||||
|
|
||||||
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
|
||||||
if self.training_tracker is not None:
|
if self.training_tracker is not None:
|
||||||
values = {"Training Loss": "No log"}
|
values = {"Training Loss": "No log"}
|
||||||
@@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback):
|
|||||||
self.training_tracker.write_line(values)
|
self.training_tracker.write_line(values)
|
||||||
self.training_tracker.remove_child()
|
self.training_tracker.remove_child()
|
||||||
self.prediction_bar = None
|
self.prediction_bar = None
|
||||||
|
# Evaluation takes a long time so we should force the next update.
|
||||||
|
self._force_next_update = True
|
||||||
|
|
||||||
def on_train_end(self, args, state, control, **kwargs):
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
self.training_tracker.update(
|
self.training_tracker.update(
|
||||||
|
|||||||
Reference in New Issue
Block a user