Add report_to training arguments to control the reporting integrations used (#9735)
This commit is contained in:
@@ -225,6 +225,21 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
|
|||||||
return best_run
|
return best_run
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_reporting_integrations():
|
||||||
|
integrations = []
|
||||||
|
if is_azureml_available():
|
||||||
|
integrations.append("azure_ml")
|
||||||
|
if is_comet_available():
|
||||||
|
integrations.append("comet_ml")
|
||||||
|
if is_mlflow_available():
|
||||||
|
integrations.append("mlflow")
|
||||||
|
if is_tensorboard_available():
|
||||||
|
integrations.append("tensorboard")
|
||||||
|
if is_wandb_available():
|
||||||
|
integrations.append("wandb")
|
||||||
|
return integrations
|
||||||
|
|
||||||
|
|
||||||
def rewrite_logs(d):
|
def rewrite_logs(d):
|
||||||
new_d = {}
|
new_d = {}
|
||||||
eval_prefix = "eval_"
|
eval_prefix = "eval_"
|
||||||
@@ -757,3 +772,21 @@ class MLflowCallback(TrainerCallback):
|
|||||||
# not let you start a new run before the previous one is killed
|
# not let you start a new run before the previous one is killed
|
||||||
if self._ml_flow.active_run is not None:
|
if self._ml_flow.active_run is not None:
|
||||||
self._ml_flow.end_run(status="KILLED")
|
self._ml_flow.end_run(status="KILLED")
|
||||||
|
|
||||||
|
|
||||||
|
INTEGRATION_TO_CALLBACK = {
|
||||||
|
"azure_ml": AzureMLCallback,
|
||||||
|
"comet_ml": CometCallback,
|
||||||
|
"mlflow": MLflowCallback,
|
||||||
|
"tensorboard": TensorBoardCallback,
|
||||||
|
"wandb": WandbCallback,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_reporting_integration_callbacks(report_to):
|
||||||
|
for integration in report_to:
|
||||||
|
if integration not in INTEGRATION_TO_CALLBACK:
|
||||||
|
raise ValueError(
|
||||||
|
f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
|
||||||
|
)
|
||||||
|
return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
|
||||||
|
|||||||
@@ -31,15 +31,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Un
|
|||||||
# Integrations must be imported before ML frameworks:
|
# Integrations must be imported before ML frameworks:
|
||||||
from .integrations import ( # isort: split
|
from .integrations import ( # isort: split
|
||||||
default_hp_search_backend,
|
default_hp_search_backend,
|
||||||
|
get_reporting_integration_callbacks,
|
||||||
hp_params,
|
hp_params,
|
||||||
is_azureml_available,
|
|
||||||
is_comet_available,
|
|
||||||
is_fairscale_available,
|
is_fairscale_available,
|
||||||
is_mlflow_available,
|
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_tune_available,
|
is_ray_tune_available,
|
||||||
is_tensorboard_available,
|
|
||||||
is_wandb_available,
|
|
||||||
run_hp_search_optuna,
|
run_hp_search_optuna,
|
||||||
run_hp_search_ray,
|
run_hp_search_ray,
|
||||||
init_deepspeed,
|
init_deepspeed,
|
||||||
@@ -124,32 +120,6 @@ if is_torch_tpu_available():
|
|||||||
import torch_xla.debug.metrics as met
|
import torch_xla.debug.metrics as met
|
||||||
import torch_xla.distributed.parallel_loader as pl
|
import torch_xla.distributed.parallel_loader as pl
|
||||||
|
|
||||||
if is_tensorboard_available():
|
|
||||||
from .integrations import TensorBoardCallback
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS.append(TensorBoardCallback)
|
|
||||||
|
|
||||||
|
|
||||||
if is_wandb_available():
|
|
||||||
from .integrations import WandbCallback
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS.append(WandbCallback)
|
|
||||||
|
|
||||||
if is_comet_available():
|
|
||||||
from .integrations import CometCallback
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS.append(CometCallback)
|
|
||||||
|
|
||||||
if is_mlflow_available():
|
|
||||||
from .integrations import MLflowCallback
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS.append(MLflowCallback)
|
|
||||||
|
|
||||||
if is_azureml_available():
|
|
||||||
from .integrations import AzureMLCallback
|
|
||||||
|
|
||||||
DEFAULT_CALLBACKS.append(AzureMLCallback)
|
|
||||||
|
|
||||||
if is_fairscale_available():
|
if is_fairscale_available():
|
||||||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
|
||||||
from fairscale.optim import OSS
|
from fairscale.optim import OSS
|
||||||
@@ -300,7 +270,8 @@ class Trainer:
|
|||||||
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
||||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||||
)
|
)
|
||||||
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
|
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||||
|
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
||||||
self.callback_handler = CallbackHandler(
|
self.callback_handler = CallbackHandler(
|
||||||
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
|
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -231,6 +231,9 @@ class TrainingArguments:
|
|||||||
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
group_by_length (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether or not to group together samples of roughly the same legnth in the training dataset (to minimize
|
Whether or not to group together samples of roughly the same legnth in the training dataset (to minimize
|
||||||
padding applied and be more efficient). Only useful if applying dynamic padding.
|
padding applied and be more efficient). Only useful if applying dynamic padding.
|
||||||
|
report_to (:obj:`List[str]`, `optional`, defaults to the list of integrations platforms installed):
|
||||||
|
The list of integrations to report the results and logs to. Supported platforms are :obj:`"azure_ml"`,
|
||||||
|
:obj:`"comet_ml"`, :obj:`"mlflow"`, :obj:`"tensorboard"` and :obj:`"wandb"`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -413,6 +416,9 @@ class TrainingArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
|
metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
|
||||||
)
|
)
|
||||||
|
report_to: Optional[List[str]] = field(
|
||||||
|
default=None, metadata={"help": "The list of integrations to report the results and logs to."}
|
||||||
|
)
|
||||||
_n_gpu: int = field(init=False, repr=False, default=-1)
|
_n_gpu: int = field(init=False, repr=False, default=-1)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -434,6 +440,11 @@ class TrainingArguments:
|
|||||||
|
|
||||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
||||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
||||||
|
if self.report_to is None:
|
||||||
|
# Import at runtime to avoid a circular import.
|
||||||
|
from .integrations import get_available_reporting_integrations
|
||||||
|
|
||||||
|
self.report_to = get_available_reporting_integrations()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||||
|
|||||||
Reference in New Issue
Block a user