Add AzureML in integrations via dedicated callback (#8062)
* first attempt to add AzureML callbacks * func arg fix * var name fix, but still won't fix error... * fixing as in https://discuss.huggingface.co/t/how-to-integrate-an-azuremlcallback-for-logging-in-azure/1713/2 * Avoid lint check of azureml import * black compliance * Make isort happy * Fix point typo in docs * Add AzureML to Callbacks docs * Attempt to make sphinx happy * Format callback docs * Make documentation style happy * Make docs compliant to style Co-authored-by: Davide Fiocco <davide.fiocco@frontiersin.net>
This commit is contained in:
@@ -13,7 +13,7 @@ subclass :class:`~transformers.Trainer` and override the methods you need (see :
|
|||||||
By default a :class:`~transformers.Trainer` will use the following callbacks:
|
By default a :class:`~transformers.Trainer` will use the following callbacks:
|
||||||
|
|
||||||
- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
|
- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
|
||||||
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
|
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the
|
||||||
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
|
logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
|
||||||
it's the second one).
|
it's the second one).
|
||||||
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
|
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
|
||||||
@@ -21,6 +21,8 @@ By default a :class:`~transformers.Trainer` will use the following callbacks:
|
|||||||
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
|
||||||
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.
|
||||||
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
|
- :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed.
|
||||||
|
- :class:`~transformers.integrations.AzureMLCallback` if `azureml-sdk <https://pypi.org/project/azureml-sdk/>`__ is
|
||||||
|
installed.
|
||||||
|
|
||||||
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
|
||||||
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
|
||||||
@@ -50,6 +52,7 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
|
|||||||
.. autoclass:: transformers.integrations.MLflowCallback
|
.. autoclass:: transformers.integrations.MLflowCallback
|
||||||
:members: setup
|
:members: setup
|
||||||
|
|
||||||
|
.. autoclass:: transformers.integrations.AzureMLCallback
|
||||||
|
|
||||||
TrainerCallback
|
TrainerCallback
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ except ImportError:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_has_tensorboard = False
|
_has_tensorboard = False
|
||||||
|
|
||||||
|
# Probe for the optional AzureML SDK; `Run` stays bound at module level for later use.
try:
    from azureml.core.run import Run  # noqa: F401
except ImportError:
    _has_azureml = False
else:
    _has_azureml = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import mlflow # noqa: F401
|
import mlflow # noqa: F401
|
||||||
|
|
||||||
@@ -68,7 +75,6 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
_has_mlflow = False
|
_has_mlflow = False
|
||||||
|
|
||||||
|
|
||||||
# No transformer imports above this point
|
# No transformer imports above this point
|
||||||
|
|
||||||
from .file_utils import is_torch_tpu_available # noqa: E402
|
from .file_utils import is_torch_tpu_available # noqa: E402
|
||||||
@@ -97,6 +103,10 @@ def is_ray_available():
|
|||||||
return _has_ray
|
return _has_ray
|
||||||
|
|
||||||
|
|
||||||
|
def is_azureml_available():
    """Return the module-level ``_has_azureml`` flag (set by the import probe at the top of this module)."""
    return _has_azureml
|
||||||
|
|
||||||
|
|
||||||
def is_mlflow_available():
|
def is_mlflow_available():
|
||||||
return _has_mlflow
|
return _has_mlflow
|
||||||
|
|
||||||
@@ -424,6 +434,27 @@ class CometCallback(TrainerCallback):
|
|||||||
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
|
experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
|
||||||
|
|
||||||
|
|
||||||
|
class AzureMLCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `AzureML
    <https://pypi.org/project/azureml-sdk/>`__.

    Args:
        azureml_run (optional):
            Object whose ``log(name, value, description=...)`` method receives the metrics (presumably an
            ``azureml.core.run.Run`` -- confirm against callers). When left as :obj:`None`, it is obtained with
            ``Run.get_context()`` on the world-zero process at the end of trainer initialization.
    """

    def __init__(self, azureml_run=None):
        # NOTE: kept as an assert so the raised exception type is unchanged; be aware that
        # asserts are skipped when Python runs with -O.
        assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        """Fetch the run from the AzureML context when none was supplied, on the world-zero process only."""
        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Forward every ``int``/``float`` entry of ``logs`` to the run; other value types are skipped."""
        # `logs` defaults to None, so guard before iterating instead of failing on `None.items()`.
        if not logs:
            return
        # Only the world-zero process reports: a run passed in by the caller exists on every
        # process and would otherwise receive each metric once per process.
        if not self.azureml_run or not state.is_world_process_zero:
            return
        for name, value in logs.items():
            if isinstance(value, (int, float)):
                self.azureml_run.log(name, value, description=name)
|
||||||
|
|
||||||
|
|
||||||
class MLflowCallback(TrainerCallback):
|
class MLflowCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
|
A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|||||||
from .integrations import ( # isort: split
|
from .integrations import ( # isort: split
|
||||||
default_hp_search_backend,
|
default_hp_search_backend,
|
||||||
hp_params,
|
hp_params,
|
||||||
|
is_azureml_available,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
@@ -154,6 +155,11 @@ if is_optuna_available():
|
|||||||
if is_ray_available():
|
if is_ray_available():
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
|
||||||
|
# Register the AzureML callback among the defaults only when the availability check passes.
if is_azureml_available():
    from .integrations import AzureMLCallback

    DEFAULT_CALLBACKS.append(AzureMLCallback)
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user