Commit the last step on world_process_zero in WandbCallback (#9805)
* Commit the last step on world_process_zero in WandbCallback * Use the environment variable WANDB_LOG_MODEL as a default value in WandbCallback
This commit is contained in:
@@ -516,6 +516,8 @@ class WandbCallback(TrainerCallback):
|
|||||||
else:
|
else:
|
||||||
self._wandb = wandb
|
self._wandb = wandb
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
# log outputs
|
||||||
|
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
|
||||||
|
|
||||||
def setup(self, args, state, model, reinit, **kwargs):
|
def setup(self, args, state, model, reinit, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -569,9 +571,6 @@ class WandbCallback(TrainerCallback):
|
|||||||
model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
|
model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
|
||||||
)
|
)
|
||||||
|
|
||||||
# log outputs
|
|
||||||
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
|
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
if self._wandb is None:
|
if self._wandb is None:
|
||||||
return
|
return
|
||||||
@@ -583,6 +582,7 @@ class WandbCallback(TrainerCallback):
|
|||||||
if self._wandb is None:
|
if self._wandb is None:
|
||||||
return
|
return
|
||||||
# commit last step
|
# commit last step
|
||||||
|
if state.is_world_process_zero:
|
||||||
self._wandb.log({})
|
self._wandb.log({})
|
||||||
if self._log_model and self._initialized and state.is_world_process_zero:
|
if self._log_model and self._initialized and state.is_world_process_zero:
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
|
|||||||
Reference in New Issue
Block a user