Integrate SwanLab for offline/online experiment tracking and local visualization (#36433)
* add swanlab integration

* feat(integrate): add SwanLab as an optional experiment tracking tool in transformers

- Integrated SwanLab into the transformers library as an alternative for experiment tracking.
- Users can now log training metrics, hyperparameters, and other experiment details to SwanLab by setting `report_to="swanlab"` in the `TrainingArguments`.
- Added necessary dependencies and documentation for SwanLab integration.

* Fix the spelling error of SwanLabCallback in callback.md

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Fix typo in comment

* Fix typo in comment

* Fix typos and update comments

* fix annotation

* chore: opt some comments

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: AAssets <20010618@qq.com>
Co-authored-by: ZeYi Lin <944270057@qq.com>
Co-authored-by: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com>
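As a rough illustration of the `report_to="swanlab"` switch described in the commit message, the sketch below prepares the environment variables and keyword arguments one might use. The project name, run name, and output directory are placeholder values, and `build_training_arguments` is a hypothetical helper (not part of this commit) that imports `transformers` lazily so the snippet itself runs without it.

```python
import os

# Environment variables documented in SwanLabCallback.setup; the values are placeholders.
os.environ["SWANLAB_MODE"] = "local"            # "local", "cloud" or "disabled" (case-sensitive)
os.environ["SWANLAB_PROJECT"] = "demo-project"  # hypothetical project name

# Keyword arguments for TrainingArguments; report_to and run_name are the SwanLab-relevant ones.
training_kwargs = {
    "output_dir": "outputs",  # placeholder path
    "run_name": "demo-run",   # the callback uses this as the experiment name
    "report_to": "swanlab",   # a list such as ["swanlab", "tensorboard"] also works
}


def build_training_arguments():
    """Hypothetical helper: build TrainingArguments when transformers is installed."""
    from transformers import TrainingArguments  # imported lazily on purpose

    return TrainingArguments(**training_kwargs)
```

Setting `SWANLAB_MODE` to `"disabled"` instead should turn the logging off, as the info message emitted by the callback suggests.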
@@ -45,6 +45,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
 - [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
 - [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
 - [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
+- [`~integrations.SwanLabCallback`] if [swanlab](http://swanlab.cn/) is installed.
 
 If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
 
@@ -92,6 +93,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
 [[autodoc]] integrations.DVCLiveCallback
     - setup
 
+[[autodoc]] integrations.SwanLabCallback
+    - setup
+
 ## TrainerCallback
 
 [[autodoc]] TrainerCallback
@@ -46,6 +46,7 @@ rendered properly in your Markdown viewer.
 - [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
 - [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
 - [`~integrations.DVCLiveCallback`] [dvclive](https://www.dvc.org/doc/dvclive) がインストールされている場合。
+- [`~integrations.SwanLabCallback`] [swanlab](http://swanlab.cn/) がインストールされている場合。
 
 パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
 
@@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
 [[autodoc]] integrations.DVCLiveCallback
     - setup
 
+[[autodoc]] integrations.SwanLabCallback
+    - setup
+
 ## TrainerCallback
 
 [[autodoc]] TrainerCallback
@@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
 - [`~integrations.DagsHubCallback`]는 [dagshub](https://dagshub.com/)이 설치되어 있으면 사용됩니다.
 - [`~integrations.FlyteCallback`]는 [flyte](https://flyte.org/)가 설치되어 있으면 사용됩니다.
 - [`~integrations.DVCLiveCallback`]는 [dvclive](https://dvc.org/doc/dvclive)가 설치되어 있으면 사용됩니다.
+- [`~integrations.SwanLabCallback`]는 [swanlab](https://swanlab.cn)가 설치되어 있으면 사용됩니다.
 
 패키지가 설치되어 있지만 해당 통합 기능을 사용하고 싶지 않다면, `TrainingArguments.report_to`를 사용하고자 하는 통합 기능 목록으로 변경할 수 있습니다 (예: `["azure_ml", "wandb"]`).
 
@@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
 [[autodoc]] integrations.DVCLiveCallback
     - setup
 
+[[autodoc]] integrations.SwanLabCallback
+    - setup
+
 ## TrainerCallback [[trainercallback]]
 
 [[autodoc]] TrainerCallback
@@ -37,6 +37,7 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
 - [`~integrations.DagsHubCallback`],如果安装了[dagshub](https://dagshub.com/)。
 - [`~integrations.FlyteCallback`],如果安装了[flyte](https://flyte.org/)。
 - [`~integrations.DVCLiveCallback`],如果安装了[dvclive](https://dvc.org/doc/dvclive)。
+- [`~integrations.SwanLabCallback`],如果安装了[swanlab](http://swanlab.cn/)。
 
 如果安装了一个软件包,但您不希望使用相关的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您想要使用的集成的列表(例如 `["azure_ml", "wandb"]`)。
 
@@ -81,6 +82,9 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
 [[autodoc]] integrations.DVCLiveCallback
     - setup
 
+[[autodoc]] integrations.SwanLabCallback
+    - setup
+
 ## TrainerCallback
 
 [[autodoc]] TrainerCallback
@@ -141,6 +141,7 @@ _import_structure = {
         "is_ray_available",
         "is_ray_tune_available",
         "is_sigopt_available",
+        "is_swanlab_available",
         "is_tensorboard_available",
         "is_wandb_available",
     ],
@@ -5267,6 +5268,7 @@ if TYPE_CHECKING:
         is_ray_available,
         is_ray_tune_available,
         is_sigopt_available,
+        is_swanlab_available,
         is_tensorboard_available,
         is_wandb_available,
     )
@@ -63,7 +63,12 @@ _import_structure = {
         "load_dequant_gguf_tensor",
         "load_gguf",
     ],
-    "higgs": ["HiggsLinear", "dequantize_higgs", "quantize_with_higgs", "replace_with_higgs_linear"],
+    "higgs": [
+        "HiggsLinear",
+        "dequantize_higgs",
+        "quantize_with_higgs",
+        "replace_with_higgs_linear",
+    ],
     "hqq": ["prepare_for_hqq_linear"],
     "integration_utils": [
         "INTEGRATION_TO_CALLBACK",
@@ -77,6 +82,7 @@ _import_structure = {
         "MLflowCallback",
         "NeptuneCallback",
         "NeptuneMissingConfiguration",
+        "SwanLabCallback",
         "TensorBoardCallback",
         "WandbCallback",
         "get_available_reporting_integrations",
@@ -96,6 +102,7 @@ _import_structure = {
         "is_ray_available",
         "is_ray_tune_available",
         "is_sigopt_available",
+        "is_swanlab_available",
         "is_tensorboard_available",
         "is_wandb_available",
         "rewrite_logs",
@@ -182,6 +189,7 @@ if TYPE_CHECKING:
         MLflowCallback,
         NeptuneCallback,
         NeptuneMissingConfiguration,
+        SwanLabCallback,
         TensorBoardCallback,
         WandbCallback,
         get_available_reporting_integrations,
@@ -201,6 +209,7 @@ if TYPE_CHECKING:
         is_ray_available,
         is_ray_tune_available,
         is_sigopt_available,
+        is_swanlab_available,
         is_tensorboard_available,
         is_wandb_available,
         rewrite_logs,
@@ -204,6 +204,10 @@ def is_dvclive_available():
     return importlib.util.find_spec("dvclive") is not None
 
 
+def is_swanlab_available():
+    return importlib.util.find_spec("swanlab") is not None
+
+
 def hp_params(trial):
     if is_optuna_available():
         import optuna
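The availability check in the hunk above relies on `importlib.util.find_spec`, which locates a top-level package without importing it. A minimal generic sketch of the same idea follows; `is_package_available` is a hypothetical helper name, not something this commit adds.

```python
import importlib.util


def is_package_available(name: str) -> bool:
    """Return True if the top-level package `name` can be located on the import path.

    find_spec only looks the package up; it does not execute the package's code.
    """
    return importlib.util.find_spec(name) is not None


# Example: a stdlib module is always found, a made-up name is not.
found_json = is_package_available("json")
found_missing = is_package_available("surely_missing_pkg_9f3a")  # placeholder name
```

Because nothing is imported, the check stays cheap even when many optional integrations are probed at start-up.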
@@ -610,6 +614,8 @@ def get_available_reporting_integrations():
         integrations.append("codecarbon")
     if is_clearml_available():
         integrations.append("clearml")
+    if is_swanlab_available():
+        integrations.append("swanlab")
     return integrations
 
 
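To make the effect of this hunk concrete, here is a simplified sketch of how a `report_to` value could be resolved against the installed integrations. It is an assumption-laden illustration, not the library's implementation: `PACKAGE_FOR_INTEGRATION`, `available_integrations`, and `resolve_report_to` are hypothetical names, and only three integrations are modelled.

```python
import importlib.util

# Integration name -> package to probe. A hypothetical subset for illustration only.
PACKAGE_FOR_INTEGRATION = {
    "tensorboard": "tensorboard",
    "wandb": "wandb",
    "swanlab": "swanlab",
}


def _installed(package):
    """Default probe: look the package up without importing it."""
    return importlib.util.find_spec(package) is not None


def available_integrations(probe=_installed):
    """List the integrations whose package the probe reports as installed."""
    return [name for name, pkg in PACKAGE_FOR_INTEGRATION.items() if probe(pkg)]


def resolve_report_to(report_to, probe=_installed):
    """Turn "all", "none", a single name, or a list of names into a list of names."""
    if report_to == "all":
        return available_integrations(probe)
    if report_to == "none":
        return []
    if isinstance(report_to, str):
        return [report_to]
    return list(report_to)
```

Passing a custom `probe` keeps the sketch easy to exercise without installing any of the packages.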
@@ -2141,6 +2147,162 @@ class DVCLiveCallback(TrainerCallback):
         self.live.end()
 
 
+class SwanLabCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
+    """
+
+    def __init__(self):
+        if not is_swanlab_available():
+            raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
+        import swanlab
+
+        self._swanlab = swanlab
+        self._initialized = False
+        self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
+
+    def setup(self, args, state, model, **kwargs):
+        """
+        Setup the optional SwanLab (*swanlab*) integration.
+
+        One can subclass and override this method to customize the setup if needed. Find more information
+        [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
+
+        You can also override the following environment variables. Find more information about environment
+        variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables).
+
+        Environment:
+        - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
+            Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
+            checks if the user is already logged in. If not, the login process is initiated.
+
+                - If a string is passed to the login interface, this environment variable is ignored.
+                - If the user is already logged in, this environment variable takes precedence over locally stored
+                  login information.
+
+        - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
+            Set this to a custom string to store results in a different project. If not specified, the name of the current
+            running directory is used.
+
+        - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
+            This environment variable specifies the storage path for log files when running in local mode.
+            By default, logs are saved in a folder named swanlog under the working directory.
+
+        - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
+            SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
+            local, cloud, and disabled. Note: Case-sensitive. Find more information
+            [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init).
+
+        - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
+            SwanLab does not currently support the save mode functionality. This feature will be available in a future
+            release.
+
+        - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
+            Web address for the SwanLab cloud environment for private version (it's free).
+
+        - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
+            API address for the SwanLab cloud environment for private version (it's free).
+
+        """
+        self._initialized = True
+
+        if state.is_world_process_zero:
+            logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
+            combined_dict = {**args.to_dict()}
+
+            if hasattr(model, "config") and model.config is not None:
+                model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
+                combined_dict = {**model_config, **combined_dict}
+            if hasattr(model, "peft_config") and model.peft_config is not None:
+                peft_config = model.peft_config
+                combined_dict = {**{"peft_config": peft_config}, **combined_dict}
+            trial_name = state.trial_name
+            init_args = {}
+            if trial_name is not None:
+                init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
+            elif args.run_name is not None:
+                init_args["experiment_name"] = args.run_name
+            init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
+
+            if self._swanlab.get_run() is None:
+                self._swanlab.init(
+                    **init_args,
+                )
+            # show transformers logo!
+            self._swanlab.config["FRAMEWORK"] = "🤗transformers"
+            # add config parameters (run may have been created manually)
+            self._swanlab.config.update(combined_dict)
+
+            # add number of model parameters to swanlab config
+            try:
+                self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
+                # get peft model parameters
+                if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
+                    trainable_params, all_param = model.get_nb_trainable_parameters()
+                    self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
+                    self._swanlab.config.update({"peft_model_all_param": all_param})
+            except AttributeError:
+                logger.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")
+
+            # log the initial model architecture to an artifact
+            if self._log_model is not None:
+                logger.warning(
+                    "SwanLab does not currently support the save mode functionality. "
+                    "This feature will be available in a future release."
+                )
+            badge_markdown = (
+                f'[<img src="https://raw.githubusercontent.com/SwanHubX/assets/main/badge1.svg"'
+                f' alt="Visualize in SwanLab" height="28'
+                f'0" height="32"/>]({self._swanlab.get_run().public.cloud.exp_url})'
+            )
+
+            modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+
+    def on_train_begin(self, args, state, control, model=None, **kwargs):
+        if not self._initialized:
+            self.setup(args, state, model, **kwargs)
+
+    def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
+        if self._log_model is not None and self._initialized and state.is_world_process_zero:
+            logger.warning(
+                "SwanLab does not currently support the save mode functionality. "
+                "This feature will be available in a future release."
+            )
+
+    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+        single_value_scalars = [
+            "train_runtime",
+            "train_samples_per_second",
+            "train_steps_per_second",
+            "train_loss",
+            "total_flos",
+        ]
+
+        if not self._initialized:
+            self.setup(args, state, model)
+        if state.is_world_process_zero:
+            for k, v in logs.items():
+                if k in single_value_scalars:
+                    self._swanlab.log({f"single_value/{k}": v})
+            non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
+            non_scalar_logs = rewrite_logs(non_scalar_logs)
+            self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step})
+
+    def on_save(self, args, state, control, **kwargs):
+        if self._log_model is not None and self._initialized and state.is_world_process_zero:
+            logger.warning(
+                "SwanLab does not currently support the save mode functionality. "
+                "This feature will be available in a future release."
+            )
+
+    def on_predict(self, args, state, control, metrics, **kwargs):
+        if not self._initialized:
+            self.setup(args, state, **kwargs)
+        if state.is_world_process_zero:
+            metrics = rewrite_logs(metrics)
+            self._swanlab.log(metrics)
+
+
 INTEGRATION_TO_CALLBACK = {
     "azure_ml": AzureMLCallback,
     "comet_ml": CometCallback,
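The `on_log` method in the hunk above separates end-of-training summary values from per-step metrics before logging. The sketch below mimics that split in plain Python so it can be run without SwanLab. `rewrite_logs_sketch` is a simplified stand-in for the library's `rewrite_logs` (assumed here to map `eval_`/`test_` prefixes to `eval/` and `test/` and everything else to `train/`), and `split_logs` is a hypothetical helper that returns dictionaries instead of calling `swanlab.log`.

```python
SINGLE_VALUE_SCALARS = {
    "train_runtime",
    "train_samples_per_second",
    "train_steps_per_second",
    "train_loss",
    "total_flos",
}


def rewrite_logs_sketch(logs):
    """Simplified stand-in for rewrite_logs: regroup keys under eval/, test/ or train/."""
    out = {}
    for key, value in logs.items():
        if key.startswith("eval_"):
            out["eval/" + key[len("eval_"):]] = value
        elif key.startswith("test_"):
            out["test/" + key[len("test_"):]] = value
        else:
            out["train/" + key] = value
    return out


def split_logs(logs, global_step):
    """Return (summary, stepwise) dictionaries mirroring the two kinds of log calls in on_log."""
    summary = {f"single_value/{k}": v for k, v in logs.items() if k in SINGLE_VALUE_SCALARS}
    stepwise = rewrite_logs_sketch({k: v for k, v in logs.items() if k not in SINGLE_VALUE_SCALARS})
    stepwise["train/global_step"] = global_step
    return summary, stepwise


summary, stepwise = split_logs({"loss": 0.5, "eval_accuracy": 0.9, "train_runtime": 12.0}, global_step=10)
```

In the callback itself each summary value is sent in its own `log` call, whereas the per-step metrics go out together with `train/global_step`.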
@@ -2153,6 +2315,7 @@ INTEGRATION_TO_CALLBACK = {
     "dagshub": DagsHubCallback,
     "flyte": FlyteCallback,
     "dvclive": DVCLiveCallback,
+    "swanlab": SwanLabCallback,
 }
 
 
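Registering `"swanlab"` in `INTEGRATION_TO_CALLBACK` is what lets a `report_to` name be turned into a callback class. A minimal sketch of such a lookup, with validation of unknown names, might look as follows. The stub classes, the `_SKETCH` mapping, and `callbacks_for` are hypothetical and only illustrate the pattern; they are not the library's code.

```python
class TensorBoardCallbackStub:
    """Placeholder standing in for a real callback class."""


class SwanLabCallbackStub:
    """Placeholder standing in for SwanLabCallback."""


INTEGRATION_TO_CALLBACK_SKETCH = {
    "tensorboard": TensorBoardCallbackStub,
    "swanlab": SwanLabCallbackStub,
}


def callbacks_for(report_to):
    """Map integration names to callback classes, rejecting names that are not registered."""
    unknown = [name for name in report_to if name not in INTEGRATION_TO_CALLBACK_SKETCH]
    if unknown:
        supported = ", ".join(sorted(INTEGRATION_TO_CALLBACK_SKETCH))
        raise ValueError(f"{unknown} not supported, choose from: {supported}")
    return [INTEGRATION_TO_CALLBACK_SKETCH[name] for name in report_to]
```

Failing early on an unknown name surfaces typos in `report_to` before training starts.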
@@ -55,6 +55,7 @@ from .integrations import (
     is_optuna_available,
     is_ray_available,
     is_sigopt_available,
+    is_swanlab_available,
     is_tensorboard_available,
     is_wandb_available,
 )
@@ -1098,6 +1099,16 @@ def require_sigopt(test_case):
     return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
 
 
+def require_swanlab(test_case):
+    """
+    Decorator marking a test that requires swanlab.
+
+    These tests are skipped when swanlab isn't installed.
+
+    """
+    return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
+
+
 def require_wandb(test_case):
     """
     Decorator marking a test that requires wandb.
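The `require_swanlab` decorator above wraps `unittest.skipUnless` around an availability check. The sketch below shows the same pattern in a self-contained form. `require_package` is a hypothetical decorator factory (the commit's decorator takes the test directly), and the package names are placeholders chosen so that one test runs and one is skipped.

```python
import importlib.util
import unittest


def require_package(name):
    """Hypothetical generalization of require_swanlab: skip unless `name` is importable."""
    available = importlib.util.find_spec(name) is not None
    return unittest.skipUnless(available, f"test requires {name}")


class ExampleTests(unittest.TestCase):
    @require_package("json")  # stdlib module, always present, so this test runs
    def test_runs(self):
        self.assertTrue(True)

    @require_package("surely_missing_pkg_9f3a")  # placeholder name, so this test is skipped
    def test_skipped(self):
        self.fail("should not run")


def run_example():
    """Run ExampleTests and return the TestResult for inspection."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTests)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

Skipped tests are reported separately from failures, so a missing optional dependency does not turn the suite red.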
@@ -451,8 +451,8 @@ class TrainingArguments:
             training step under the keyword argument `mems`.
         run_name (`str`, *optional*, defaults to `output_dir`):
             A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),
-            [mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will
-            be the same as `output_dir`.
+            [mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and [swanlab](https://swanlab.cn)
+            logging. If not specified, will be the same as `output_dir`.
         disable_tqdm (`bool`, *optional*):
             Whether or not to disable the tqdm progress bars and table of metrics produced by
             [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
@@ -642,8 +642,8 @@ class TrainingArguments:
         report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
             The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
             `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
-            `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
-            integrations.
+            `"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"`
+            for no integrations.
         ddp_find_unused_parameters (`bool`, *optional*):
             When using distributed training, the value of the flag `find_unused_parameters` passed to
             `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
@@ -1187,7 +1187,9 @@ class TrainingArguments:
 
     run_name: Optional[str] = field(
         default=None,
-        metadata={"help": "An optional descriptor for the run. Notably used for wandb, mlflow and comet logging."},
+        metadata={
+            "help": "An optional descriptor for the run. Notably used for wandb, mlflow, comet and swanlab logging."
+        },
     )
     disable_tqdm: Optional[bool] = field(
         default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
@@ -2848,8 +2850,8 @@ class TrainingArguments:
             report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
                 The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
                 `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
-                `"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
-                `"none"` for no integrations.
+                `"neptune"`, `"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations
+                installed, `"none"` for no integrations.
             first_step (`bool`, *optional*, defaults to `False`):
                 Whether to log and evaluate the first `global_step` or not.
             nan_inf_filter (`bool`, *optional*, defaults to `True`):
@@ -160,7 +160,7 @@ class TFTrainingArguments(TrainingArguments):
             Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
             automatically detect from metadata.
         run_name (`str`, *optional*):
-            A descriptor for the run. Notably used for wandb, mlflow and comet logging.
+            A descriptor for the run. Notably used for wandb, mlflow, comet and swanlab logging.
         xla (`bool`, *optional*):
             Whether to activate the XLA compilation or not.
     """