Trainer: add logging through Weights & Biases (#3916)

* feat: add logging through Weights & Biases

* feat(wandb): make logging compatible with all scripts

* style(trainer.py): fix formatting

* [Trainer] Tweak wandb integration

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Boris Dayma
2020-05-04 21:42:27 -05:00
committed by GitHub
parent 858b1d1e5a
commit 818463ee8e
3 changed files with 41 additions and 2 deletions

1
.gitignore vendored
View File

@@ -131,6 +131,7 @@ proc_data
# examples # examples
runs runs
/runs_old /runs_old
/wandb
examples/runs examples/runs
# data # data

View File

@@ -52,6 +52,18 @@ def is_tensorboard_available():
return _has_tensorboard return _has_tensorboard
try:
import wandb
_has_wandb = True
except ImportError:
_has_wandb = False
def is_wandb_available():
return _has_wandb
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -151,6 +163,10 @@ class Trainer:
logger.warning( logger.warning(
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it." "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
) )
if not is_wandb_available():
logger.info(
"You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
)
set_seed(self.args.seed) set_seed(self.args.seed)
# Create output directory if needed # Create output directory if needed
if self.args.local_rank in [-1, 0]: if self.args.local_rank in [-1, 0]:
@@ -209,6 +225,12 @@ class Trainer:
) )
return optimizer, scheduler return optimizer, scheduler
def _setup_wandb(self):
# Start a wandb run and log config parameters
wandb.init(name=self.args.logging_dir, config=vars(self.args))
# keep track of model topology and gradients
# wandb.watch(self.model)
def train(self, model_path: Optional[str] = None): def train(self, model_path: Optional[str] = None):
""" """
Main training entry point. Main training entry point.
@@ -263,6 +285,9 @@ class Trainer:
if self.tb_writer is not None: if self.tb_writer is not None:
self.tb_writer.add_text("args", self.args.to_json_string()) self.tb_writer.add_text("args", self.args.to_json_string())
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
if is_wandb_available():
self._setup_wandb()
# Train! # Train!
logger.info("***** Running training *****") logger.info("***** Running training *****")
@@ -351,6 +376,9 @@ class Trainer:
if self.tb_writer: if self.tb_writer:
for k, v in logs.items(): for k, v in logs.items():
self.tb_writer.add_scalar(k, v, global_step) self.tb_writer.add_scalar(k, v, global_step)
if is_wandb_available():
wandb.log(logs, step=global_step)
epoch_iterator.write(json.dumps({**logs, **{"step": global_step}})) epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
if self.args.save_steps > 0 and global_step % self.args.save_steps == 0: if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
@@ -467,7 +495,7 @@ class Trainer:
shutil.rmtree(checkpoint) shutil.rmtree(checkpoint)
def evaluate( def evaluate(
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
) -> Dict[str, float]: ) -> Dict[str, float]:
""" """
Run evaluation and return metrics. Run evaluation and return metrics.

View File

@@ -2,7 +2,7 @@ import dataclasses
import json import json
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Tuple from typing import Any, Dict, Optional, Tuple
from .file_utils import cached_property, is_torch_available, torch_required from .file_utils import cached_property, is_torch_available, torch_required
@@ -138,3 +138,13 @@ class TrainingArguments:
Serializes this instance to a JSON string. Serializes this instance to a JSON string.
""" """
return json.dumps(dataclasses.asdict(self), indent=2) return json.dumps(dataclasses.asdict(self), indent=2)
def to_sanitized_dict(self) -> Dict[str, Any]:
"""
Sanitized serialization to use with TensorBoards hparams
"""
d = dataclasses.asdict(self)
valid_types = [bool, int, float, str]
if is_torch_available():
valid_types.append(torch.Tensor)
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}