From cedc547e7e009e4745db350505848fd5c4f8f6f3 Mon Sep 17 00:00:00 2001 From: Jay Mody Date: Mon, 3 Aug 2020 09:00:39 -0400 Subject: [PATCH] Adds train_batch_size, eval_batch_size, and n_gpu to to_sanitized_dict output for logging. (#5331) * Adds train_batch_size, eval_batch_size, and n_gpu to to_sanitized_dict() output * Update wandb config logging to use to_sanitized_dict * removed n_gpu from sanitized dict * fix quality check errors --- src/transformers/trainer.py | 2 +- src/transformers/training_args.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 8a3209355a..e1429713fb 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -383,7 +383,7 @@ class Trainer: logger.info( 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' ) - wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) + wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=self.args.to_sanitized_dict()) # keep track of model topology and gradients, unsupported on TPU if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": wandb.watch( diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e6506d9763..ad33266a81 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -310,7 +310,10 @@ class TrainingArguments: Sanitized serialization to use with TensorBoard’s hparams """ d = dataclasses.asdict(self) + d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}} + valid_types = [bool, int, float, str] if is_torch_available(): valid_types.append(torch.Tensor) + return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}