From 62ba64b90a2786c74e202d9f1ad2e34da4e083ce Mon Sep 17 00:00:00 2001 From: peridotml <106936600+peridotml@users.noreply.github.com> Date: Tue, 30 May 2023 07:08:07 -0700 Subject: [PATCH] Adds a FlyteCallback (#23759) * initial flyte callback * lint * logs should still be saved to Flyte even if pandas isn't install (unlikely) * cr - flyte team * add docs for Flytecallback * fix doc string - cr sgugger * Apply suggestions from code review cr - sgugger fix doc strings Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/main_classes/callback.mdx | 3 + src/transformers/integrations.py | 76 +++++++++++++++++++++++- 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/docs/source/en/main_classes/callback.mdx b/docs/source/en/main_classes/callback.mdx index 33ae17c66d..2636130473 100644 --- a/docs/source/en/main_classes/callback.mdx +++ b/docs/source/en/main_classes/callback.mdx @@ -39,6 +39,7 @@ By default a [`Trainer`] will use the following callbacks: installed. - [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed. - [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed. +- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed. The main class that implements callbacks is [`TrainerCallback`]. It gets the [`TrainingArguments`] used to instantiate the [`Trainer`], can access that @@ -79,6 +80,8 @@ Here is the list of the available [`TrainerCallback`] in the library: [[autodoc]] integrations.DagsHubCallback +[[autodoc]] integrations.FlyteCallback + ## TrainerCallback [[autodoc]] TrainerCallback diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 4b0e1c590d..0a85cef698 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -30,7 +30,7 @@ from typing import TYPE_CHECKING, Dict, Optional import numpy as np from . import __version__ as version -from .utils import flatten_dict, is_datasets_available, is_torch_available, logging +from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging from .utils.versions import importlib_metadata @@ -146,6 +146,16 @@ def is_codecarbon_available(): return importlib.util.find_spec("codecarbon") is not None +def is_flytekit_available(): + return importlib.util.find_spec("flytekit") is not None + + +def is_flyte_deck_standard_available(): + if not is_flytekit_available(): + return False + return importlib.util.find_spec("flytekitplugins.deck") is not None + + def hp_params(trial): if is_optuna_available(): import optuna @@ -1537,6 +1547,69 @@ class ClearMLCallback(TrainerCallback): self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False) +class FlyteCallback(TrainerCallback): + """A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/). + NOTE: This callback only works within a Flyte task. + + Args: + save_log_history (`bool`, *optional*, defaults to `True`): + When set to True, the training logs are saved as a Flyte Deck. + + sync_checkpoints (`bool`, *optional*, defaults to `True`): + When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an + interruption. + + Example: + + ```python + # Note: This example skips over some setup steps for brevity. + from flytekit import current_context, task + + + @task + def train_hf_transformer(): + cp = current_context().checkpoint + trainer = Trainer(..., callbacks=[FlyteCallback()]) + output = trainer.train(resume_from_checkpoint=cp.restore()) + ``` + """ + + def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True): + super().__init__() + if not is_flytekit_available(): + raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.") + + if not is_flyte_deck_standard_available() or not is_pandas_available(): + logger.warning( + "Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. " + "Run `pip install flytekitplugins-deck-standard pandas` to enable this feature." + ) + save_log_history = False + + from flytekit import current_context + + self.cp = current_context().checkpoint + self.save_log_history = save_log_history + self.sync_checkpoints = sync_checkpoints + + def on_save(self, args, state, control, **kwargs): + if self.sync_checkpoints and state.is_world_process_zero: + ckpt_dir = f"checkpoint-{state.global_step}" + artifact_path = os.path.join(args.output_dir, ckpt_dir) + + logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.") + self.cp.save(artifact_path) + + def on_train_end(self, args, state, control, **kwargs): + if self.save_log_history: + import pandas as pd + from flytekit import Deck + from flytekitplugins.deck.renderer import TableRenderer + + log_history_df = pd.DataFrame(state.log_history) + Deck("Log History", TableRenderer().to_html(log_history_df)) + + INTEGRATION_TO_CALLBACK = { "azure_ml": AzureMLCallback, "comet_ml": CometCallback, @@ -1547,6 +1620,7 @@ INTEGRATION_TO_CALLBACK = { "codecarbon": CodeCarbonCallback, "clearml": ClearMLCallback, "dagshub": DagsHubCallback, + "flyte": FlyteCallback, }