allow an already created tensorboard SummaryWriter be passed to Trainer

This commit is contained in:
jaymody
2020-04-27 14:42:57 -04:00
committed by Julien Chaumond
parent 8e67573a64
commit 858b1d1e5a

View File

@@ -123,6 +123,7 @@ class Trainer:
eval_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
prediction_loss_only=False, prediction_loss_only=False,
tb_writer: Optional["SummaryWriter"] = None,
): ):
""" """
Trainer is a simple but feature-complete training and eval loop for PyTorch, Trainer is a simple but feature-complete training and eval loop for PyTorch,
@@ -142,7 +143,9 @@ class Trainer:
self.eval_dataset = eval_dataset self.eval_dataset = eval_dataset
self.compute_metrics = compute_metrics self.compute_metrics = compute_metrics
self.prediction_loss_only = prediction_loss_only self.prediction_loss_only = prediction_loss_only
if is_tensorboard_available() and self.args.local_rank in [-1, 0]: if tb_writer is not None:
self.tb_writer = tb_writer
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
if not is_tensorboard_available(): if not is_tensorboard_available():
logger.warning( logger.warning(