Adds a FlyteCallback (#23759)
* initial flyte callback * lint * logs should still be saved to Flyte even if pandas isn't install (unlikely) * cr - flyte team * add docs for Flytecallback * fix doc string - cr sgugger * Apply suggestions from code review cr - sgugger fix doc strings Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -39,6 +39,7 @@ By default a [`Trainer`] will use the following callbacks:
|
|||||||
installed.
|
installed.
|
||||||
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
|
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
|
||||||
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
||||||
|
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
||||||
|
|
||||||
The main class that implements callbacks is [`TrainerCallback`]. It gets the
|
The main class that implements callbacks is [`TrainerCallback`]. It gets the
|
||||||
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
|
[`TrainingArguments`] used to instantiate the [`Trainer`], can access that
|
||||||
@@ -79,6 +80,8 @@ Here is the list of the available [`TrainerCallback`] in the library:
|
|||||||
|
|
||||||
[[autodoc]] integrations.DagsHubCallback
|
[[autodoc]] integrations.DagsHubCallback
|
||||||
|
|
||||||
|
[[autodoc]] integrations.FlyteCallback
|
||||||
|
|
||||||
## TrainerCallback
|
## TrainerCallback
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from typing import TYPE_CHECKING, Dict, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from . import __version__ as version
|
from . import __version__ as version
|
||||||
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
|
from .utils import flatten_dict, is_datasets_available, is_pandas_available, is_torch_available, logging
|
||||||
from .utils.versions import importlib_metadata
|
from .utils.versions import importlib_metadata
|
||||||
|
|
||||||
|
|
||||||
@@ -146,6 +146,16 @@ def is_codecarbon_available():
|
|||||||
return importlib.util.find_spec("codecarbon") is not None
|
return importlib.util.find_spec("codecarbon") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_flytekit_available():
|
||||||
|
return importlib.util.find_spec("flytekit") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_flyte_deck_standard_available():
|
||||||
|
if not is_flytekit_available():
|
||||||
|
return False
|
||||||
|
return importlib.util.find_spec("flytekitplugins.deck") is not None
|
||||||
|
|
||||||
|
|
||||||
def hp_params(trial):
|
def hp_params(trial):
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
import optuna
|
import optuna
|
||||||
@@ -1537,6 +1547,69 @@ class ClearMLCallback(TrainerCallback):
|
|||||||
self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)
|
self._clearml_task.update_output_model(artifact_path, iteration=state.global_step, auto_delete_file=False)
|
||||||
|
|
||||||
|
|
||||||
|
class FlyteCallback(TrainerCallback):
|
||||||
|
"""A [`TrainerCallback`] that sends the logs to [Flyte](https://flyte.org/).
|
||||||
|
NOTE: This callback only works within a Flyte task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
save_log_history (`bool`, *optional*, defaults to `True`):
|
||||||
|
When set to True, the training logs are saved as a Flyte Deck.
|
||||||
|
|
||||||
|
sync_checkpoints (`bool`, *optional*, defaults to `True`):
|
||||||
|
When set to True, checkpoints are synced with Flyte and can be used to resume training in the case of an
|
||||||
|
interruption.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Note: This example skips over some setup steps for brevity.
|
||||||
|
from flytekit import current_context, task
|
||||||
|
|
||||||
|
|
||||||
|
@task
|
||||||
|
def train_hf_transformer():
|
||||||
|
cp = current_context().checkpoint
|
||||||
|
trainer = Trainer(..., callbacks=[FlyteCallback()])
|
||||||
|
output = trainer.train(resume_from_checkpoint=cp.restore())
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, save_log_history: bool = True, sync_checkpoints: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
if not is_flytekit_available():
|
||||||
|
raise ImportError("FlyteCallback requires flytekit to be installed. Run `pip install flytekit`.")
|
||||||
|
|
||||||
|
if not is_flyte_deck_standard_available() or not is_pandas_available():
|
||||||
|
logger.warning(
|
||||||
|
"Syncing log history requires both flytekitplugins-deck-standard and pandas to be installed. "
|
||||||
|
"Run `pip install flytekitplugins-deck-standard pandas` to enable this feature."
|
||||||
|
)
|
||||||
|
save_log_history = False
|
||||||
|
|
||||||
|
from flytekit import current_context
|
||||||
|
|
||||||
|
self.cp = current_context().checkpoint
|
||||||
|
self.save_log_history = save_log_history
|
||||||
|
self.sync_checkpoints = sync_checkpoints
|
||||||
|
|
||||||
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
if self.sync_checkpoints and state.is_world_process_zero:
|
||||||
|
ckpt_dir = f"checkpoint-{state.global_step}"
|
||||||
|
artifact_path = os.path.join(args.output_dir, ckpt_dir)
|
||||||
|
|
||||||
|
logger.info(f"Syncing checkpoint in {ckpt_dir} to Flyte. This may take time.")
|
||||||
|
self.cp.save(artifact_path)
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, **kwargs):
|
||||||
|
if self.save_log_history:
|
||||||
|
import pandas as pd
|
||||||
|
from flytekit import Deck
|
||||||
|
from flytekitplugins.deck.renderer import TableRenderer
|
||||||
|
|
||||||
|
log_history_df = pd.DataFrame(state.log_history)
|
||||||
|
Deck("Log History", TableRenderer().to_html(log_history_df))
|
||||||
|
|
||||||
|
|
||||||
INTEGRATION_TO_CALLBACK = {
|
INTEGRATION_TO_CALLBACK = {
|
||||||
"azure_ml": AzureMLCallback,
|
"azure_ml": AzureMLCallback,
|
||||||
"comet_ml": CometCallback,
|
"comet_ml": CometCallback,
|
||||||
@@ -1547,6 +1620,7 @@ INTEGRATION_TO_CALLBACK = {
|
|||||||
"codecarbon": CodeCarbonCallback,
|
"codecarbon": CodeCarbonCallback,
|
||||||
"clearml": ClearMLCallback,
|
"clearml": ClearMLCallback,
|
||||||
"dagshub": DagsHubCallback,
|
"dagshub": DagsHubCallback,
|
||||||
|
"flyte": FlyteCallback,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user