From 995006eabb4a7099d4c4273265cfd7545a9ce0b9 Mon Sep 17 00:00:00 2001 From: Davide Fiocco Date: Tue, 27 Oct 2020 19:21:54 +0100 Subject: [PATCH] Add AzureML in integrations via dedicated callback (#8062) * first attempt to add AzureML callbacks * func arg fix * var name fix, but still won't fix error... * fixing as in https://discuss.huggingface.co/t/how-to-integrate-an-azuremlcallback-for-logging-in-azure/1713/2 * Avoid lint check of azureml import * black compliance * Make isort happy * Fix point typo in docs * Add AzureML to Callbacks docs * Attempt to make sphinx happy * Format callback docs * Make documentation style happy * Make docs compliant to style Co-authored-by: Davide Fiocco --- docs/source/main_classes/callback.rst | 5 +++- src/transformers/integrations.py | 33 ++++++++++++++++++++++++++- src/transformers/trainer.py | 6 +++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst index 16b1318b71..f146244c1f 100644 --- a/docs/source/main_classes/callback.rst +++ b/docs/source/main_classes/callback.rst @@ -13,7 +13,7 @@ subclass :class:`~transformers.Trainer` and override the methods you need (see : By default a :class:`~transformers.Trainer` will use the following callbacks: - :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation. -- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the +- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise it's the second one). 
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4 @@ -21,6 +21,8 @@ By default a :class:`~transformers.Trainer` will use the following callbacks: - :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed. - :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed. - :class:`~transformers.integrations.MLflowCallback` if `mlflow <https://www.mlflow.org/>`__ is installed. +- :class:`~transformers.integrations.AzureMLCallback` if `azureml-sdk <https://pypi.org/project/azureml-sdk/>`__ is + installed. The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the :class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that @@ -50,6 +52,7 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the .. autoclass:: transformers.integrations.MLflowCallback :members: setup +.. autoclass:: transformers.integrations.AzureMLCallback TrainerCallback ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index db9a81c0fa..3e7c232ad8 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -61,6 +61,13 @@ except ImportError: except ImportError: _has_tensorboard = False +try: + from azureml.core.run import Run # noqa: F401 + + _has_azureml = True +except ImportError: + _has_azureml = False + try: import mlflow # noqa: F401 @@ -68,7 +75,6 @@ try: except ImportError: _has_mlflow = False - # No transformer imports above this point from .file_utils import is_torch_tpu_available # noqa: E402 @@ -97,6 +103,10 @@ def is_ray_available(): return _has_ray +def is_azureml_available(): + return _has_azureml + + def is_mlflow_available(): return _has_mlflow @@ -424,6 +434,27 @@ class CometCallback(TrainerCallback): experiment._log_metrics(logs, step=state.global_step, 
epoch=state.epoch, framework="transformers") +class AzureMLCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that sends the logs to `AzureML + `__. + """ + + def __init__(self, azureml_run=None): + assert _has_azureml, "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`." + self.azureml_run = azureml_run + + def on_init_end(self, args, state, control, **kwargs): + if self.azureml_run is None and state.is_world_process_zero: + self.azureml_run = Run.get_context() + + def on_log(self, args, state, control, logs=None, **kwargs): + if self.azureml_run: + for k, v in logs.items(): + if isinstance(v, (int, float)): + self.azureml_run.log(k, v, description=k) + + class MLflowCallback(TrainerCallback): """ A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow `__. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 942ddde04d..69b346d063 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -31,6 +31,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from .integrations import ( # isort: split default_hp_search_backend, hp_params, + is_azureml_available, is_comet_available, is_mlflow_available, is_optuna_available, @@ -154,6 +155,11 @@ if is_optuna_available(): if is_ray_available(): from ray import tune +if is_azureml_available(): + from .integrations import AzureMLCallback + + DEFAULT_CALLBACKS.append(AzureMLCallback) + logger = logging.get_logger(__name__)