From eba418ac5df71d08927efb7e3b738833998162ff Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Tue, 26 Jan 2021 19:21:26 +0100 Subject: [PATCH] Commit the last step on world_process_zero in WandbCallback (#9805) * Commit the last step on world_process_zero in WandbCallback * Use the environment variable WANDB_LOG_MODEL as a default value in WandbCallback --- src/transformers/integrations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 6a49635e6c..3682b8bd04 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -516,6 +516,8 @@ class WandbCallback(TrainerCallback): else: self._wandb = wandb self._initialized = False + # log outputs + self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}) def setup(self, args, state, model, reinit, **kwargs): """ @@ -569,9 +571,6 @@ class WandbCallback(TrainerCallback): model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps) ) - # log outputs - self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}) - def on_train_begin(self, args, state, control, model=None, **kwargs): if self._wandb is None: return @@ -583,7 +582,8 @@ class WandbCallback(TrainerCallback): if self._wandb is None: return # commit last step - self._wandb.log({}) + if state.is_world_process_zero: + self._wandb.log({}) if self._log_model and self._initialized and state.is_world_process_zero: from .trainer import Trainer