From 8b332a6a160c6df82e4267aaf118d87377d78a67 Mon Sep 17 00:00:00 2001 From: neverix Date: Fri, 8 Jul 2022 23:44:24 +0300 Subject: [PATCH] Make predict() close progress bars after finishing (#17952) (#18078) * Make Trainer.predict call on_evaluate (#17952) * Add on_predict * Small fix * Small and different fix * Add tests --- src/transformers/trainer.py | 1 + src/transformers/trainer_callback.py | 15 +++++++++++++++ src/transformers/utils/notebook.py | 5 +++++ tests/trainer/test_trainer_callback.py | 3 +++ 4 files changed, 24 insertions(+) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e817a9e141..0bebc8626b 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2713,6 +2713,7 @@ class Trainer: ) ) + self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) self._memory_tracker.stop_and_update_metrics(output.metrics) return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py index 06875b74e1..8749e5f3f5 100644 --- a/src/transformers/trainer_callback.py +++ b/src/transformers/trainer_callback.py @@ -262,6 +262,12 @@ class TrainerCallback: """ pass + def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs): + """ + Event called after a successful prediction. + """ + pass + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): """ Event called after a checkpoint save. @@ -372,6 +378,9 @@ class CallbackHandler(TrainerCallback): control.should_evaluate = False return self.call_event("on_evaluate", args, state, control, metrics=metrics) + def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + return self.call_event("on_predict", args, state, control, metrics=metrics) + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): control.should_save = False return self.call_event("on_save", args, state, control) @@ -484,6 +493,12 @@ class ProgressCallback(TrainerCallback): self.prediction_bar.close() self.prediction_bar = None + def on_predict(self, args, state, control, **kwargs): + if state.is_local_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + def on_log(self, args, state, control, logs=None, **kwargs): if state.is_local_process_zero and self.training_bar is not None: _ = logs.pop("total_flos", None) diff --git a/src/transformers/utils/notebook.py b/src/transformers/utils/notebook.py index f671ad737c..8d81d76c4f 100644 --- a/src/transformers/utils/notebook.py +++ b/src/transformers/utils/notebook.py @@ -307,6 +307,11 @@ class NotebookProgressCallback(TrainerCallback): else: self.prediction_bar.update(self.prediction_bar.value + 1) + def on_predict(self, args, state, control, **kwargs): + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + def on_log(self, args, state, control, logs=None, **kwargs): # Only for when there is no evaluation if args.evaluation_strategy == IntervalStrategy.NO and "loss" in logs: diff --git a/tests/trainer/test_trainer_callback.py b/tests/trainer/test_trainer_callback.py index a7daee4fd0..a88ca1cb0d 100644 --- a/tests/trainer/test_trainer_callback.py +++ b/tests/trainer/test_trainer_callback.py @@ -66,6 +66,9 @@ class MyTestTrainerCallback(TrainerCallback): def on_evaluate(self, args, state, control, **kwargs): self.events.append("on_evaluate") + def on_predict(self, args, state, control, **kwargs): + self.events.append("on_predict") + def on_save(self, args, state, control, **kwargs): self.events.append("on_save")