From bd0eab351a338175053998ddfc059f1cb6424ab4 Mon Sep 17 00:00:00 2001 From: Teven Date: Wed, 5 Aug 2020 15:05:52 +0200 Subject: [PATCH] Trainer + wandb quality of life logging tweaks (#6241) * added `name` argument for wandb logging, also logging model config with trainer arguments * Update src/transformers/training_args.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * added tf, post-review changes Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/trainer.py | 5 ++++- src/transformers/trainer_tf.py | 3 ++- src/transformers/training_args.py | 6 ++++++ src/transformers/training_args_tf.py | 2 ++ 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e1429713fb..10674c0620 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -383,7 +383,10 @@ class Trainer: logger.info( 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' ) - wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=self.args.to_sanitized_dict()) + combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} + wandb.init( + project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name + ) # keep track of model topology and gradients, unsupported on TPU if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": wandb.watch( diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index aaca022e81..d388017437 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -215,7 +215,8 @@ class TFTrainer: return self._setup_wandb() logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"') - wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) + combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} + wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name) def prediction_loop( self, diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ad33266a81..713293e349 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -109,6 +109,8 @@ class TrainingArguments: make use of the past hidden states for their predictions. If this argument is set to a positive int, the ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model at the next training step under the keyword argument ``mems``. + run_name (:obj:`str`, `optional`): + A descriptor for the run. Notably used for wandb logging. """ output_dir: str = field( @@ -222,6 +224,10 @@ class TrainingArguments: metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, ) + run_name: Optional[str] = field( + default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."} + ) + @property def train_batch_size(self) -> int: """ diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index 0adf344645..c1dea84d58 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -95,6 +95,8 @@ class TFTrainingArguments(TrainingArguments): at the next training step under the keyword argument ``mems``. tpu_name (:obj:`str`, `optional`): The name of the TPU the process is running on. + run_name (:obj:`str`, `optional`): + A descriptor for the run. Notably used for wandb logging. """ tpu_name: str = field(