allow an already created tensorboard SummaryWriter be passed to Trainer
This commit is contained in:
@@ -123,6 +123,7 @@ class Trainer:
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
prediction_loss_only=False,
|
||||
tb_writer: Optional["SummaryWriter"] = None,
|
||||
):
|
||||
"""
|
||||
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
||||
@@ -142,7 +143,9 @@ class Trainer:
|
||||
self.eval_dataset = eval_dataset
|
||||
self.compute_metrics = compute_metrics
|
||||
self.prediction_loss_only = prediction_loss_only
|
||||
if is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
||||
if tb_writer is not None:
|
||||
self.tb_writer = tb_writer
|
||||
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
||||
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
||||
if not is_tensorboard_available():
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user