Adds dvclive callback (#27352)
* dvclive trainer callback * style fixes * dvclive link fixes
This commit is contained in:
@@ -44,6 +44,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
|
|||||||
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
|
- [`~integrations.ClearMLCallback`] if [clearml](https://github.com/allegroai/clearml) is installed.
|
||||||
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
||||||
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
||||||
|
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
|
||||||
|
|
||||||
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
|
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
|
||||||
|
|
||||||
@@ -88,6 +89,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
|
|||||||
|
|
||||||
[[autodoc]] integrations.FlyteCallback
|
[[autodoc]] integrations.FlyteCallback
|
||||||
|
|
||||||
|
[[autodoc]] integrations.DVCLiveCallback
|
||||||
|
- setup
|
||||||
|
|
||||||
## TrainerCallback
|
## TrainerCallback
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
|
|||||||
- [`~integrations.ClearMLCallback`] [clearml](https://github.com/allegroai/clearml) がインストールされている場合。
|
- [`~integrations.ClearMLCallback`] [clearml](https://github.com/allegroai/clearml) がインストールされている場合。
|
||||||
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
|
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
|
||||||
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
|
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
|
||||||
|
- [`~integrations.DVCLiveCallback`] [dvclive](https://dvc.org/doc/dvclive) がインストールされている場合。
|
||||||
|
|
||||||
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
|
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
|
||||||
|
|
||||||
@@ -88,6 +89,9 @@ rendered properly in your Markdown viewer.
|
|||||||
|
|
||||||
[[autodoc]] integrations.FlyteCallback
|
[[autodoc]] integrations.FlyteCallback
|
||||||
|
|
||||||
|
[[autodoc]] integrations.DVCLiveCallback
|
||||||
|
- setup
|
||||||
|
|
||||||
## TrainerCallback
|
## TrainerCallback
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -201,6 +201,7 @@ You can easily log and monitor your runs code. The following are currently suppo
|
|||||||
* [Comet ML](https://www.comet.ml/docs/python-sdk/huggingface/)
|
* [Comet ML](https://www.comet.ml/docs/python-sdk/huggingface/)
|
||||||
* [Neptune](https://docs.neptune.ai/integrations-and-supported-tools/model-training/hugging-face)
|
* [Neptune](https://docs.neptune.ai/integrations-and-supported-tools/model-training/hugging-face)
|
||||||
* [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps)
|
* [ClearML](https://clear.ml/docs/latest/docs/getting_started/ds/ds_first_steps)
|
||||||
|
* [DVCLive](https://dvc.org/doc/dvclive/ml-frameworks/huggingface)
|
||||||
|
|
||||||
### Weights & Biases
|
### Weights & Biases
|
||||||
|
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ _import_structure = {
|
|||||||
"integrations": [
|
"integrations": [
|
||||||
"is_clearml_available",
|
"is_clearml_available",
|
||||||
"is_comet_available",
|
"is_comet_available",
|
||||||
|
"is_dvclive_available",
|
||||||
"is_neptune_available",
|
"is_neptune_available",
|
||||||
"is_optuna_available",
|
"is_optuna_available",
|
||||||
"is_ray_available",
|
"is_ray_available",
|
||||||
@@ -4300,6 +4301,7 @@ if TYPE_CHECKING:
|
|||||||
from .integrations import (
|
from .integrations import (
|
||||||
is_clearml_available,
|
is_clearml_available,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
|
is_dvclive_available,
|
||||||
is_neptune_available,
|
is_neptune_available,
|
||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_available,
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ _import_structure = {
|
|||||||
"CodeCarbonCallback",
|
"CodeCarbonCallback",
|
||||||
"CometCallback",
|
"CometCallback",
|
||||||
"DagsHubCallback",
|
"DagsHubCallback",
|
||||||
|
"DVCLiveCallback",
|
||||||
"FlyteCallback",
|
"FlyteCallback",
|
||||||
"MLflowCallback",
|
"MLflowCallback",
|
||||||
"NeptuneCallback",
|
"NeptuneCallback",
|
||||||
@@ -58,6 +59,7 @@ _import_structure = {
|
|||||||
"is_codecarbon_available",
|
"is_codecarbon_available",
|
||||||
"is_comet_available",
|
"is_comet_available",
|
||||||
"is_dagshub_available",
|
"is_dagshub_available",
|
||||||
|
"is_dvclive_available",
|
||||||
"is_flyte_deck_standard_available",
|
"is_flyte_deck_standard_available",
|
||||||
"is_flytekit_available",
|
"is_flytekit_available",
|
||||||
"is_mlflow_available",
|
"is_mlflow_available",
|
||||||
@@ -105,6 +107,7 @@ if TYPE_CHECKING:
|
|||||||
CodeCarbonCallback,
|
CodeCarbonCallback,
|
||||||
CometCallback,
|
CometCallback,
|
||||||
DagsHubCallback,
|
DagsHubCallback,
|
||||||
|
DVCLiveCallback,
|
||||||
FlyteCallback,
|
FlyteCallback,
|
||||||
MLflowCallback,
|
MLflowCallback,
|
||||||
NeptuneCallback,
|
NeptuneCallback,
|
||||||
@@ -119,6 +122,7 @@ if TYPE_CHECKING:
|
|||||||
is_codecarbon_available,
|
is_codecarbon_available,
|
||||||
is_comet_available,
|
is_comet_available,
|
||||||
is_dagshub_available,
|
is_dagshub_available,
|
||||||
|
is_dvclive_available,
|
||||||
is_flyte_deck_standard_available,
|
is_flyte_deck_standard_available,
|
||||||
is_flytekit_available,
|
is_flytekit_available,
|
||||||
is_mlflow_available,
|
is_mlflow_available,
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import sys
|
|||||||
import tempfile
|
import tempfile
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -152,6 +152,10 @@ def is_flyte_deck_standard_available():
|
|||||||
return importlib.util.find_spec("flytekitplugins.deck") is not None
|
return importlib.util.find_spec("flytekitplugins.deck") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_dvclive_available():
    """Return `True` when an import spec for the `dvclive` package can be found, `False` otherwise."""
    dvclive_spec = importlib.util.find_spec("dvclive")
    return dvclive_spec is not None
|
||||||
|
|
||||||
|
|
||||||
def hp_params(trial):
|
def hp_params(trial):
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
import optuna
|
import optuna
|
||||||
@@ -541,6 +545,8 @@ def get_available_reporting_integrations():
|
|||||||
integrations.append("comet_ml")
|
integrations.append("comet_ml")
|
||||||
if is_dagshub_available():
|
if is_dagshub_available():
|
||||||
integrations.append("dagshub")
|
integrations.append("dagshub")
|
||||||
|
if is_dvclive_available():
|
||||||
|
integrations.append("dvclive")
|
||||||
if is_mlflow_available():
|
if is_mlflow_available():
|
||||||
integrations.append("mlflow")
|
integrations.append("mlflow")
|
||||||
if is_neptune_available():
|
if is_neptune_available():
|
||||||
@@ -1605,6 +1611,98 @@ class FlyteCallback(TrainerCallback):
|
|||||||
Deck("Log History", TableRenderer().to_html(log_history_df))
|
Deck("Log History", TableRenderer().to_html(log_history_df))
|
||||||
|
|
||||||
|
|
||||||
|
class DVCLiveCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [DVCLive](https://dvc.org/doc/dvclive).

    Use the environment variables below in `setup` to configure the integration. To customize this callback beyond
    those environment variables, see [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

    Args:
        live (`dvclive.Live`, *optional*, defaults to `None`):
            Optional Live instance. If None, a new instance will be created using **kwargs.
        log_model (Union[Literal["all"], bool], *optional*, defaults to `None`):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True`,
            the final checkpoint is logged at the end of training. If set to `"all"`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint. If left as `None`, the
            `HF_DVCLIVE_LOG_MODEL` environment variable is consulted in `setup`.
        kwargs:
            Keyword arguments forwarded to `dvclive.Live` when this callback has to create the instance itself.
            They are ignored when `live` is provided.

    Raises:
        RuntimeError: If `dvclive` is not installed, or if `live` is neither `None` nor a `dvclive.Live` instance.
    """

    def __init__(
        self,
        live: Optional[Any] = None,
        log_model: Optional[Union[Literal["all"], bool]] = None,
        **kwargs,
    ):
        if not is_dvclive_available():
            raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.")
        from dvclive import Live

        self._log_model = log_model
        # Kept so that `setup` can build the `Live` instance as the class docstring promises.
        self._live_kwargs = kwargs

        self._initialized = False
        self.live = None
        if isinstance(live, Live):
            # A user-supplied instance is treated as already configured, so `setup` is not run for it.
            self.live = live
            self._initialized = True
        elif live is not None:
            raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live")

    def setup(self, args, state, model):
        """
        Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see
        [here](https://dvc.org/doc/dvclive/ml-frameworks/huggingface).

        Environment:
        - **HF_DVCLIVE_LOG_MODEL** (`str`, *optional*):
            Whether to use `dvclive.Live.log_artifact()` to log checkpoints created by [`Trainer`]. If set to `True` or
            *1*, the final checkpoint is logged at the end of training. If set to `all`, the entire
            [`TrainingArguments`]'s `output_dir` is logged at each checkpoint. Only read when `log_model` was not
            passed to the constructor.
        """
        from dvclive import Live

        self._initialized = True
        # An explicit `log_model` argument wins; the environment variable is only a fallback.
        if self._log_model is None:
            # Default to a non-true string so an unset variable does not break the `.upper()` / `.lower()` calls.
            log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE")
            if log_model_env.upper() in ENV_VARS_TRUE_VALUES:
                self._log_model = True
            elif log_model_env.lower() == "all":
                self._log_model = "all"
        if state.is_world_process_zero:
            if self.live is None:
                self.live = Live(**self._live_kwargs)
            self.live.log_params(args.to_dict())

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Run `setup` once, unless the callback is already initialized."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """Log every entry of `logs` as a DVCLive metric on the world-zero process, then advance the step."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            from dvclive.utils import standardize_metric_name

            # `logs` defaults to None; treat that as "nothing to log" rather than failing on `.items()`.
            for key, value in (logs or {}).items():
                self.live.log_metric(standardize_metric_name(key, "dvclive.huggingface"), value)
            self.live.next_step()

    def on_save(self, args, state, control, **kwargs):
        """When `log_model` is `"all"`, log the whole `output_dir` as an artifact at each checkpoint."""
        if self._log_model == "all" and self._initialized and state.is_world_process_zero:
            self.live.log_artifact(args.output_dir)

    def on_train_end(self, args, state, control, **kwargs):
        """Optionally save and log the final model as an artifact, then end the DVCLive run."""
        if self._initialized and state.is_world_process_zero:
            from transformers.trainer import Trainer

            if self._log_model is True:
                # A throwaway Trainer is used only for its `save_model` implementation.
                fake_trainer = Trainer(args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer"))
                name = "best" if args.load_best_model_at_end else "last"
                output_dir = os.path.join(args.output_dir, name)
                fake_trainer.save_model(output_dir)
                self.live.log_artifact(output_dir, name=name, type="model", copy=True)
            self.live.end()
|
||||||
|
|
||||||
|
|
||||||
INTEGRATION_TO_CALLBACK = {
|
INTEGRATION_TO_CALLBACK = {
|
||||||
"azure_ml": AzureMLCallback,
|
"azure_ml": AzureMLCallback,
|
||||||
"comet_ml": CometCallback,
|
"comet_ml": CometCallback,
|
||||||
@@ -1616,6 +1714,7 @@ INTEGRATION_TO_CALLBACK = {
|
|||||||
"clearml": ClearMLCallback,
|
"clearml": ClearMLCallback,
|
||||||
"dagshub": DagsHubCallback,
|
"dagshub": DagsHubCallback,
|
||||||
"flyte": FlyteCallback,
|
"flyte": FlyteCallback,
|
||||||
|
"dvclive": DVCLiveCallback,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -509,7 +509,7 @@ class TrainingArguments:
|
|||||||
instance of `Dataset`.
|
instance of `Dataset`.
|
||||||
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
||||||
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
||||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
||||||
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
|
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
|
||||||
integrations.
|
integrations.
|
||||||
ddp_find_unused_parameters (`bool`, *optional*):
|
ddp_find_unused_parameters (`bool`, *optional*):
|
||||||
@@ -2391,9 +2391,9 @@ class TrainingArguments:
|
|||||||
and lets the application set the level.
|
and lets the application set the level.
|
||||||
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
||||||
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
||||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
|
||||||
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
|
`"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
|
||||||
integrations.
|
`"none"` for no integrations.
|
||||||
first_step (`bool`, *optional*, defaults to `False`):
|
first_step (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to log and evaluate the first `global_step` or not.
|
Whether to log and evaluate the first `global_step` or not.
|
||||||
nan_inf_filter (`bool`, *optional*, defaults to `True`):
|
nan_inf_filter (`bool`, *optional*, defaults to `True`):
|
||||||
|
|||||||
Reference in New Issue
Block a user