From 2ce3ddab2db6e46606e8d5cc0a2c05795d0ccd97 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 15 Oct 2020 10:30:34 -0400 Subject: [PATCH] Small fixes to NotebookProgressCallback (#7813) --- src/transformers/file_utils.py | 2 +- src/transformers/utils/notebook.py | 27 +++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 03668662b4..ea319bcae9 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -153,7 +153,7 @@ try: import IPython # noqa: F401 _in_notebook = True -except: # noqa: E722 +except (AttributeError, ImportError, KeyError): _in_notebook = False diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index ad8551c19f..4dd02e26f3 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -19,6 +19,7 @@ from typing import Optional import IPython.display as disp from ..trainer_callback import TrainerCallback +from ..trainer_utils import EvaluationStrategy def format_time(t): @@ -146,7 +147,7 @@ class NotebookProgressBar: self.first_calls = self.warmup self.wait_for = 1 self.update_bar(value) - elif value <= self.last_value: + elif value <= self.last_value and not force_update: return elif force_update or self.first_calls > 0 or value >= min(self.last_value + self.wait_for, self.total): if self.first_calls > 0: @@ -272,17 +273,25 @@ class NotebookProgressCallback(TrainerCallback): def __init__(self): self.training_tracker = None self.prediction_bar = None + self._force_next_update = False def on_train_begin(self, args, state, control, **kwargs): - self.first_column = "Epoch" if args.max_steps <= 0 else "Step" + self.first_column = "Epoch" if args.evaluation_strategy == EvaluationStrategy.EPOCH else "Step" self.training_loss = 0 self.last_log = 0 - column_names = [self.first_column] + ["Training Loss", "Validation 
Loss"] + column_names = [self.first_column] + ["Training Loss"] + if args.evaluation_strategy != EvaluationStrategy.NO: + column_names.append("Validation Loss") self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) def on_step_end(self, args, state, control, **kwargs): epoch = int(state.epoch) if int(state.epoch) == state.epoch else f"{state.epoch:.2f}" - self.training_tracker.update(state.global_step + 1, comment=f"Epoch {epoch}/{state.num_train_epochs}") + self.training_tracker.update( + state.global_step + 1, + comment=f"Epoch {epoch}/{state.num_train_epochs}", + force_update=self._force_next_update, + ) + self._force_next_update = False def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): if self.prediction_bar is None: @@ -294,6 +303,14 @@ class NotebookProgressCallback(TrainerCallback): else: self.prediction_bar.update(self.prediction_bar.value + 1) + def on_log(self, args, state, control, logs=None, **kwargs): + # Only for when there is no evaluation + if args.evaluation_strategy == EvaluationStrategy.NO and "loss" in logs: + values = {"Training Loss": logs["loss"]} + # First column is necessarily Step since we're not in epoch eval strategy + values["Step"] = state.global_step + self.training_tracker.write_line(values) + def on_evaluate(self, args, state, control, metrics=None, **kwargs): if self.training_tracker is not None: values = {"Training Loss": "No log"} @@ -319,6 +336,8 @@ class NotebookProgressCallback(TrainerCallback): self.training_tracker.write_line(values) self.training_tracker.remove_child() self.prediction_bar = None + # Evaluation takes a long time so we should force the next update. + self._force_next_update = True def on_train_end(self, args, state, control, **kwargs): self.training_tracker.update(