Trainer: add logging through Weights & Biases (#3916)
* feat: add logging through Weights & Biases
* feat(wandb): make logging compatible with all scripts
* style(trainer.py): fix formatting
* [Trainer] Tweak wandb integration

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -131,6 +131,7 @@ proc_data
|
|||||||
# examples
|
# examples
|
||||||
runs
|
runs
|
||||||
/runs_old
|
/runs_old
|
||||||
|
/wandb
|
||||||
examples/runs
|
examples/runs
|
||||||
|
|
||||||
# data
|
# data
|
||||||
|
|||||||
@@ -52,6 +52,18 @@ def is_tensorboard_available():
|
|||||||
return _has_tensorboard
|
return _has_tensorboard
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
_has_wandb = True
|
||||||
|
except ImportError:
|
||||||
|
_has_wandb = False
|
||||||
|
|
||||||
|
|
||||||
|
def is_wandb_available():
|
||||||
|
return _has_wandb
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -151,6 +163,10 @@ class Trainer:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
||||||
)
|
)
|
||||||
|
if not is_wandb_available():
|
||||||
|
logger.info(
|
||||||
|
"You are instantiating a Trainer but wandb is not installed. Install it to use Weights & Biases logging."
|
||||||
|
)
|
||||||
set_seed(self.args.seed)
|
set_seed(self.args.seed)
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if self.args.local_rank in [-1, 0]:
|
if self.args.local_rank in [-1, 0]:
|
||||||
@@ -209,6 +225,12 @@ class Trainer:
|
|||||||
)
|
)
|
||||||
return optimizer, scheduler
|
return optimizer, scheduler
|
||||||
|
|
||||||
|
def _setup_wandb(self):
|
||||||
|
# Start a wandb run and log config parameters
|
||||||
|
wandb.init(name=self.args.logging_dir, config=vars(self.args))
|
||||||
|
# keep track of model topology and gradients
|
||||||
|
# wandb.watch(self.model)
|
||||||
|
|
||||||
def train(self, model_path: Optional[str] = None):
|
def train(self, model_path: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Main training entry point.
|
Main training entry point.
|
||||||
@@ -263,6 +285,9 @@ class Trainer:
|
|||||||
|
|
||||||
if self.tb_writer is not None:
|
if self.tb_writer is not None:
|
||||||
self.tb_writer.add_text("args", self.args.to_json_string())
|
self.tb_writer.add_text("args", self.args.to_json_string())
|
||||||
|
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
||||||
|
if is_wandb_available():
|
||||||
|
self._setup_wandb()
|
||||||
|
|
||||||
# Train!
|
# Train!
|
||||||
logger.info("***** Running training *****")
|
logger.info("***** Running training *****")
|
||||||
@@ -351,6 +376,9 @@ class Trainer:
|
|||||||
if self.tb_writer:
|
if self.tb_writer:
|
||||||
for k, v in logs.items():
|
for k, v in logs.items():
|
||||||
self.tb_writer.add_scalar(k, v, global_step)
|
self.tb_writer.add_scalar(k, v, global_step)
|
||||||
|
if is_wandb_available():
|
||||||
|
wandb.log(logs, step=global_step)
|
||||||
|
|
||||||
epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
|
epoch_iterator.write(json.dumps({**logs, **{"step": global_step}}))
|
||||||
|
|
||||||
if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
|
if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
|
||||||
@@ -467,7 +495,7 @@ class Trainer:
|
|||||||
shutil.rmtree(checkpoint)
|
shutil.rmtree(checkpoint)
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None
|
self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Run evaluation and return metrics.
|
Run evaluation and return metrics.
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import dataclasses
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
from .file_utils import cached_property, is_torch_available, torch_required
|
from .file_utils import cached_property, is_torch_available, torch_required
|
||||||
|
|
||||||
@@ -138,3 +138,13 @@ class TrainingArguments:
|
|||||||
Serializes this instance to a JSON string.
|
Serializes this instance to a JSON string.
|
||||||
"""
|
"""
|
||||||
return json.dumps(dataclasses.asdict(self), indent=2)
|
return json.dumps(dataclasses.asdict(self), indent=2)
|
||||||
|
|
||||||
|
def to_sanitized_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Sanitized serialization to use with TensorBoard’s hparams
|
||||||
|
"""
|
||||||
|
d = dataclasses.asdict(self)
|
||||||
|
valid_types = [bool, int, float, str]
|
||||||
|
if is_torch_available():
|
||||||
|
valid_types.append(torch.Tensor)
|
||||||
|
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
|
||||||
|
|||||||
Reference in New Issue
Block a user