From 858b1d1e5a4c2b3868be6deceeca8bf0a8f3f68f Mon Sep 17 00:00:00 2001 From: jaymody Date: Mon, 27 Apr 2020 14:42:57 -0400 Subject: [PATCH] allow an already created tensorboard SummaryWriter be passed to Trainer --- src/transformers/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a7b2bae457..661c60f88e 100644 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -123,6 +123,7 @@ class Trainer: eval_dataset: Optional[Dataset] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, prediction_loss_only=False, + tb_writer: Optional["SummaryWriter"] = None, ): """ Trainer is a simple but feature-complete training and eval loop for PyTorch, @@ -142,7 +143,9 @@ class Trainer: self.eval_dataset = eval_dataset self.compute_metrics = compute_metrics self.prediction_loss_only = prediction_loss_only - if is_tensorboard_available() and self.args.local_rank in [-1, 0]: + if tb_writer is not None: + self.tb_writer = tb_writer + elif is_tensorboard_available() and self.args.local_rank in [-1, 0]: self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) if not is_tensorboard_available(): logger.warning(