move wandb/comet logger init to train() to allow parallel logging (#6850)
* move wandb/comet logger init to train() to allow parallel logging * Setup wandb/comet loggers on first call to log()
This commit is contained in:
@@ -255,20 +255,10 @@ class Trainer:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
||||||
)
|
)
|
||||||
if is_wandb_available():
|
|
||||||
self.setup_wandb()
|
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
||||||
elif os.environ.get("WANDB_DISABLED") != "true":
|
self._loggers_initialized = False
|
||||||
logger.info(
|
|
||||||
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
|
|
||||||
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
|
|
||||||
)
|
|
||||||
if is_comet_available():
|
|
||||||
self.setup_comet()
|
|
||||||
elif os.environ.get("COMET_MODE") != "DISABLED":
|
|
||||||
logger.info(
|
|
||||||
"To use comet_ml logging, run `pip/conda install comet_ml` "
|
|
||||||
"see https://www.comet.ml/docs/python-sdk/huggingface/"
|
|
||||||
)
|
|
||||||
# Create output directory if needed
|
# Create output directory if needed
|
||||||
if self.is_world_process_zero():
|
if self.is_world_process_zero():
|
||||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||||
@@ -518,6 +508,25 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
return len(dataloader.dataset)
|
return len(dataloader.dataset)
|
||||||
|
|
||||||
|
def _setup_loggers(self):
|
||||||
|
if self._loggers_initialized:
|
||||||
|
return
|
||||||
|
if is_wandb_available():
|
||||||
|
self.setup_wandb()
|
||||||
|
elif os.environ.get("WANDB_DISABLED") != "true":
|
||||||
|
logger.info(
|
||||||
|
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
|
||||||
|
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
|
||||||
|
)
|
||||||
|
if is_comet_available():
|
||||||
|
self.setup_comet()
|
||||||
|
elif os.environ.get("COMET_MODE") != "DISABLED":
|
||||||
|
logger.info(
|
||||||
|
"To use comet_ml logging, run `pip/conda install comet_ml` "
|
||||||
|
"see https://www.comet.ml/docs/python-sdk/huggingface/"
|
||||||
|
)
|
||||||
|
self._loggers_initialized = True
|
||||||
|
|
||||||
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
||||||
""" HP search setup code """
|
""" HP search setup code """
|
||||||
if self.hp_search_backend is None or trial is None:
|
if self.hp_search_backend is None or trial is None:
|
||||||
@@ -903,6 +912,9 @@ class Trainer:
|
|||||||
iterator (:obj:`tqdm`, `optional`):
|
iterator (:obj:`tqdm`, `optional`):
|
||||||
A potential tqdm progress bar to write the logs on.
|
A potential tqdm progress bar to write the logs on.
|
||||||
"""
|
"""
|
||||||
|
# Set up loggers like W&B or Comet ML
|
||||||
|
self._setup_loggers()
|
||||||
|
|
||||||
if hasattr(self, "_log"):
|
if hasattr(self, "_log"):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
||||||
|
|||||||
Reference in New Issue
Block a user