allow an already created tensorboard SummaryWriter be passed to Trainer
This commit is contained in:
@@ -123,6 +123,7 @@ class Trainer:
|
|||||||
eval_dataset: Optional[Dataset] = None,
|
eval_dataset: Optional[Dataset] = None,
|
||||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||||
prediction_loss_only=False,
|
prediction_loss_only=False,
|
||||||
|
tb_writer: Optional["SummaryWriter"] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
||||||
@@ -142,7 +143,9 @@ class Trainer:
|
|||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.prediction_loss_only = prediction_loss_only
|
self.prediction_loss_only = prediction_loss_only
|
||||||
if is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
if tb_writer is not None:
|
||||||
|
self.tb_writer = tb_writer
|
||||||
|
elif is_tensorboard_available() and self.args.local_rank in [-1, 0]:
|
||||||
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
||||||
if not is_tensorboard_available():
|
if not is_tensorboard_available():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
Reference in New Issue
Block a user