Add report_to training arguments to control the reporting integrations used (#9735)
This commit is contained in:
@@ -225,6 +225,21 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
||||
return best_run
|
||||
|
||||
|
||||
def get_available_reporting_integrations():
    """
    List the reporting integrations that can be used right now.

    Every known integration is probed through its ``is_*_available`` helper, always in the same order
    (``azure_ml``, ``comet_ml``, ``mlflow``, ``tensorboard``, ``wandb``).

    Returns:
        :obj:`List[str]`: The names of the integrations whose availability helper returned a truthy value, in
        probing order.
    """
    # (integration name, availability probe) pairs; the order here fixes the order of the result.
    availability_probes = (
        ("azure_ml", is_azureml_available),
        ("comet_ml", is_comet_available),
        ("mlflow", is_mlflow_available),
        ("tensorboard", is_tensorboard_available),
        ("wandb", is_wandb_available),
    )
    return [name for name, is_available in availability_probes if is_available()]
|
||||
|
||||
|
||||
def rewrite_logs(d):
|
||||
new_d = {}
|
||||
eval_prefix = "eval_"
|
||||
@@ -757,3 +772,21 @@ class MLflowCallback(TrainerCallback):
|
||||
# not let you start a new run before the previous one is killed
|
||||
if self._ml_flow.active_run is not None:
|
||||
self._ml_flow.end_run(status="KILLED")
|
||||
|
||||
|
||||
# Registry mapping each integration name accepted in `report_to` to the callback class that implements it.
INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
}


def get_reporting_integration_callbacks(report_to):
    """
    Resolve reporting integration names to their callback classes.

    Args:
        report_to (:obj:`str` or iterable of :obj:`str`):
            The integration name(s) to look up in :obj:`INTEGRATION_TO_CALLBACK`. A single string is treated as
            one integration name.

    Returns:
        :obj:`list`: The callback classes matching :obj:`report_to`, in the same order.

    Raises:
        :obj:`ValueError`: If a name is not a key of :obj:`INTEGRATION_TO_CALLBACK`.
    """
    # A bare string would otherwise be iterated character by character and fail with a misleading message.
    if isinstance(report_to, str):
        report_to = [report_to]

    callbacks = []
    # Validate and resolve in a single pass so that one-shot iterables (e.g. generators) are not exhausted by
    # the validation step before the callbacks are collected.
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK)} are supported."
            )
        callbacks.append(INTEGRATION_TO_CALLBACK[integration])
    return callbacks
|
||||
|
||||
@@ -31,15 +31,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
||||
# Integrations must be imported before ML frameworks:
|
||||
from .integrations import ( # isort: split
|
||||
default_hp_search_backend,
|
||||
get_reporting_integration_callbacks,
|
||||
hp_params,
|
||||
is_azureml_available,
|
||||
is_comet_available,
|
||||
is_fairscale_available,
|
||||
is_mlflow_available,
|
||||
is_optuna_available,
|
||||
is_ray_tune_available,
|
||||
is_tensorboard_available,
|
||||
is_wandb_available,
|
||||
run_hp_search_optuna,
|
||||
run_hp_search_ray,
|
||||
init_deepspeed,
|
||||
@@ -124,32 +120,6 @@ if is_torch_tpu_available():
|
||||
import torch_xla.debug.metrics as met
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
|
||||
if is_tensorboard_available():
|
||||
from .integrations import TensorBoardCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(TensorBoardCallback)
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
from .integrations import WandbCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(WandbCallback)
|
||||
|
||||
if is_comet_available():
|
||||
from .integrations import CometCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(CometCallback)
|
||||
|
||||
if is_mlflow_available():
|
||||
from .integrations import MLflowCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(MLflowCallback)
|
||||
|
||||
if is_azureml_available():
|
||||
from .integrations import AzureMLCallback
|
||||
|
||||
DEFAULT_CALLBACKS.append(AzureMLCallback)
|
||||
|
||||
if is_fairscale_available():
|
||||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
||||
from fairscale.optim import OSS
|
||||
@@ -300,7 +270,8 @@ class Trainer:
|
||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
||||
self.callback_handler = CallbackHandler(
|
||||
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
|
||||
@@ -231,6 +231,9 @@ class TrainingArguments:
|
||||
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
|
||||
padding applied and be more efficient). Only useful if applying dynamic padding.
|
||||
report_to (:obj:`List[str]`, `optional`, defaults to the list of integration platforms installed):
|
||||
The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
|
||||
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
|
||||
"""
|
||||
|
||||
output_dir: str = field(
|
||||
@@ -413,6 +416,9 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
|
||||
)
|
||||
report_to: Optional[List[str]] = field(
|
||||
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
|
||||
)
|
||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -434,6 +440,11 @@ class TrainingArguments:
|
||||
|
||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
||||
if self.report_to is None:
|
||||
# Import at runtime to avoid a circular import.
|
||||
from .integrations import get_available_reporting_integrations
|
||||
|
||||
self.report_to = get_available_reporting_integrations()
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
|
||||
Reference in New Issue
Block a user