Integrate SwanLab for offline/online experiment tracking and local visualization (#36433)
* add swanlab integration
* feat(integrate): add SwanLab as an optional experiment tracking tool in transformers
  - Integrated SwanLab into the transformers library as an alternative for experiment tracking.
  - Users can now log training metrics, hyperparameters, and other experiment details to SwanLab by setting `report_to="swanlab"` in the `TrainingArguments`.
  - Added necessary dependencies and documentation for SwanLab integration.
* Fix the spelling error of SwanLabCallback in callback.md
* Apply suggestions from code review
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* Fix typo in comment
* Fix typo in comment
* Fix typos and update comments
* fix annotation
* chore: opt some comments

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: AAssets <20010618@qq.com>
Co-authored-by: ZeYi Lin <944270057@qq.com>
Co-authored-by: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com>
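A rough usage sketch of the `report_to="swanlab"` setting mentioned in the commit message. Only the `report_to` value comes from this commit; the other keyword values are illustrative assumptions, and `build_training_arguments` is a hypothetical helper that constructs `TrainingArguments` only when both packages can be located.

```python
import importlib.util

# Illustrative values; only `report_to="swanlab"` is taken from this commit.
training_kwargs = {
    "output_dir": "./outputs",
    "run_name": "swanlab-demo",
    "logging_steps": 10,
    "report_to": "swanlab",
}


def build_training_arguments(kwargs):
    """Hypothetical helper: return TrainingArguments, or None if a dependency is missing."""
    for package in ("transformers", "swanlab"):
        if importlib.util.find_spec(package) is None:
            return None
    # Imported lazily so this sketch can be loaded without transformers installed.
    from transformers import TrainingArguments

    return TrainingArguments(**kwargs)
```

Passing the resulting object to a `Trainer` should then register the SwanLab callback through the `INTEGRATION_TO_CALLBACK` mapping that this commit extends.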
@@ -45,6 +45,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
|
|||||||
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
|
||||||
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
|
||||||
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
|
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
|
||||||
|
- [`~integrations.SwanLabCallback`] if [swanlab](http://swanlab.cn/) is installed.
|
||||||
|
|
||||||
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
|
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
|
||||||
|
|
||||||
@@ -92,6 +93,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
|
|||||||
[[autodoc]] integrations.DVCLiveCallback
|
[[autodoc]] integrations.DVCLiveCallback
|
||||||
- setup
|
- setup
|
||||||
|
|
||||||
|
[[autodoc]] integrations.SwanLabCallback
|
||||||
|
- setup
|
||||||
|
|
||||||
## TrainerCallback
|
## TrainerCallback
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ rendered properly in your Markdown viewer.
|
|||||||
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
|
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
|
||||||
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
|
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
|
||||||
- [`~integrations.DVCLiveCallback`] [dvclive](https://www.dvc.org/doc/dvclive) がインストールされている場合。
|
- [`~integrations.DVCLiveCallback`] [dvclive](https://www.dvc.org/doc/dvclive) がインストールされている場合。
|
||||||
|
- [`~integrations.SwanLabCallback`] [swanlab](http://swanlab.cn/) がインストールされている場合。
|
||||||
|
|
||||||
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
|
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
|
||||||
|
|
||||||
@@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
|
|||||||
[[autodoc]] integrations.DVCLiveCallback
|
[[autodoc]] integrations.DVCLiveCallback
|
||||||
- setup
|
- setup
|
||||||
|
|
||||||
|
[[autodoc]] integrations.SwanLabCallback
|
||||||
|
- setup
|
||||||
|
|
||||||
## TrainerCallback
|
## TrainerCallback
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
|
|||||||
- [`~integrations.DagsHubCallback`]는 [dagshub](https://dagshub.com/)이 설치되어 있으면 사용됩니다.
|
- [`~integrations.DagsHubCallback`]는 [dagshub](https://dagshub.com/)이 설치되어 있으면 사용됩니다.
|
||||||
- [`~integrations.FlyteCallback`]는 [flyte](https://flyte.org/)가 설치되어 있으면 사용됩니다.
|
- [`~integrations.FlyteCallback`]는 [flyte](https://flyte.org/)가 설치되어 있으면 사용됩니다.
|
||||||
- [`~integrations.DVCLiveCallback`]는 [dvclive](https://dvc.org/doc/dvclive)가 설치되어 있으면 사용됩니다.
|
- [`~integrations.DVCLiveCallback`]는 [dvclive](https://dvc.org/doc/dvclive)가 설치되어 있으면 사용됩니다.
|
||||||
|
- [`~integrations.SwanLabCallback`]는 [swanlab](https://swanlab.cn)가 설치되어 있으면 사용됩니다.
|
||||||
|
|
||||||
패키지가 설치되어 있지만 해당 통합 기능을 사용하고 싶지 않다면, `TrainingArguments.report_to`를 사용하고자 하는 통합 기능 목록으로 변경할 수 있습니다 (예: `["azure_ml", "wandb"]`).
|
패키지가 설치되어 있지만 해당 통합 기능을 사용하고 싶지 않다면, `TrainingArguments.report_to`를 사용하고자 하는 통합 기능 목록으로 변경할 수 있습니다 (예: `["azure_ml", "wandb"]`).
|
||||||
|
|
||||||
@@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
|
|||||||
[[autodoc]] integrations.DVCLiveCallback
|
[[autodoc]] integrations.DVCLiveCallback
|
||||||
- setup
|
- setup
|
||||||
|
|
||||||
|
[[autodoc]] integrations.SwanLabCallback
|
||||||
|
- setup
|
||||||
|
|
||||||
## TrainerCallback [[trainercallback]]
|
## TrainerCallback [[trainercallback]]
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
|
|||||||
- [`~integrations.DagsHubCallback`],如果安装了[dagshub](https://dagshub.com/)。
|
- [`~integrations.DagsHubCallback`],如果安装了[dagshub](https://dagshub.com/)。
|
||||||
- [`~integrations.FlyteCallback`],如果安装了[flyte](https://flyte.org/)。
|
- [`~integrations.FlyteCallback`],如果安装了[flyte](https://flyte.org/)。
|
||||||
- [`~integrations.DVCLiveCallback`],如果安装了[dvclive](https://dvc.org/doc/dvclive)。
|
- [`~integrations.DVCLiveCallback`],如果安装了[dvclive](https://dvc.org/doc/dvclive)。
|
||||||
|
- [`~integrations.SwanLabCallback`],如果安装了[swanlab](http://swanlab.cn/)。
|
||||||
|
|
||||||
如果安装了一个软件包,但您不希望使用相关的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您想要使用的集成的列表(例如 `["azure_ml", "wandb"]`)。
|
如果安装了一个软件包,但您不希望使用相关的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您想要使用的集成的列表(例如 `["azure_ml", "wandb"]`)。
|
||||||
|
|
||||||
@@ -81,6 +82,9 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
|
|||||||
[[autodoc]] integrations.DVCLiveCallback
|
[[autodoc]] integrations.DVCLiveCallback
|
||||||
- setup
|
- setup
|
||||||
|
|
||||||
|
[[autodoc]] integrations.SwanLabCallback
|
||||||
|
- setup
|
||||||
|
|
||||||
## TrainerCallback
|
## TrainerCallback
|
||||||
|
|
||||||
[[autodoc]] TrainerCallback
|
[[autodoc]] TrainerCallback
|
||||||
|
|||||||
@@ -141,6 +141,7 @@ _import_structure = {
|
|||||||
"is_ray_available",
|
"is_ray_available",
|
||||||
"is_ray_tune_available",
|
"is_ray_tune_available",
|
||||||
"is_sigopt_available",
|
"is_sigopt_available",
|
||||||
|
"is_swanlab_available",
|
||||||
"is_tensorboard_available",
|
"is_tensorboard_available",
|
||||||
"is_wandb_available",
|
"is_wandb_available",
|
||||||
],
|
],
|
||||||
@@ -5267,6 +5268,7 @@ if TYPE_CHECKING:
|
|||||||
is_ray_available,
|
is_ray_available,
|
||||||
is_ray_tune_available,
|
is_ray_tune_available,
|
||||||
is_sigopt_available,
|
is_sigopt_available,
|
||||||
|
is_swanlab_available,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -63,7 +63,12 @@ _import_structure = {
|
|||||||
"load_dequant_gguf_tensor",
|
"load_dequant_gguf_tensor",
|
||||||
"load_gguf",
|
"load_gguf",
|
||||||
],
|
],
|
||||||
"higgs": ["HiggsLinear", "dequantize_higgs", "quantize_with_higgs", "replace_with_higgs_linear"],
|
"higgs": [
|
||||||
|
"HiggsLinear",
|
||||||
|
"dequantize_higgs",
|
||||||
|
"quantize_with_higgs",
|
||||||
|
"replace_with_higgs_linear",
|
||||||
|
],
|
||||||
"hqq": ["prepare_for_hqq_linear"],
|
"hqq": ["prepare_for_hqq_linear"],
|
||||||
"integration_utils": [
|
"integration_utils": [
|
||||||
"INTEGRATION_TO_CALLBACK",
|
"INTEGRATION_TO_CALLBACK",
|
||||||
@@ -77,6 +82,7 @@ _import_structure = {
|
|||||||
"MLflowCallback",
|
"MLflowCallback",
|
||||||
"NeptuneCallback",
|
"NeptuneCallback",
|
||||||
"NeptuneMissingConfiguration",
|
"NeptuneMissingConfiguration",
|
||||||
|
"SwanLabCallback",
|
||||||
"TensorBoardCallback",
|
"TensorBoardCallback",
|
||||||
"WandbCallback",
|
"WandbCallback",
|
||||||
"get_available_reporting_integrations",
|
"get_available_reporting_integrations",
|
||||||
@@ -96,6 +102,7 @@ _import_structure = {
|
|||||||
"is_ray_available",
|
"is_ray_available",
|
||||||
"is_ray_tune_available",
|
"is_ray_tune_available",
|
||||||
"is_sigopt_available",
|
"is_sigopt_available",
|
||||||
|
"is_swanlab_available",
|
||||||
"is_tensorboard_available",
|
"is_tensorboard_available",
|
||||||
"is_wandb_available",
|
"is_wandb_available",
|
||||||
"rewrite_logs",
|
"rewrite_logs",
|
||||||
@@ -182,6 +189,7 @@ if TYPE_CHECKING:
|
|||||||
MLflowCallback,
|
MLflowCallback,
|
||||||
NeptuneCallback,
|
NeptuneCallback,
|
||||||
NeptuneMissingConfiguration,
|
NeptuneMissingConfiguration,
|
||||||
|
SwanLabCallback,
|
||||||
TensorBoardCallback,
|
TensorBoardCallback,
|
||||||
WandbCallback,
|
WandbCallback,
|
||||||
get_available_reporting_integrations,
|
get_available_reporting_integrations,
|
||||||
@@ -201,6 +209,7 @@ if TYPE_CHECKING:
|
|||||||
is_ray_available,
|
is_ray_available,
|
||||||
is_ray_tune_available,
|
is_ray_tune_available,
|
||||||
is_sigopt_available,
|
is_sigopt_available,
|
||||||
|
is_swanlab_available,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
rewrite_logs,
|
rewrite_logs,
|
||||||
|
|||||||
@@ -204,6 +204,10 @@ def is_dvclive_available():
|
|||||||
return importlib.util.find_spec("dvclive") is not None
|
return importlib.util.find_spec("dvclive") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def is_swanlab_available():
|
||||||
|
return importlib.util.find_spec("swanlab") is not None
|
||||||
|
|
||||||
|
|
||||||
def hp_params(trial):
|
def hp_params(trial):
|
||||||
if is_optuna_available():
|
if is_optuna_available():
|
||||||
import optuna
|
import optuna
|
||||||
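The availability check added in this hunk relies on `importlib.util.find_spec`, which locates a package without importing it. A minimal standalone sketch of the same pattern follows; `is_package_available` and `available_integrations` are hypothetical names used for illustration, not functions from the library.

```python
import importlib.util


def is_package_available(name: str) -> bool:
    """Return True if a top-level package can be located, without importing it."""
    return importlib.util.find_spec(name) is not None


def available_integrations(candidates):
    """Keep only the candidate package names that are installed."""
    return [name for name in candidates if is_package_available(name)]
```

Because nothing is imported, a check like this stays cheap even when the optional package is heavy, which is presumably why the integration helpers use it.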
@@ -610,6 +614,8 @@ def get_available_reporting_integrations():
|
|||||||
integrations.append("codecarbon")
|
integrations.append("codecarbon")
|
||||||
if is_clearml_available():
|
if is_clearml_available():
|
||||||
integrations.append("clearml")
|
integrations.append("clearml")
|
||||||
|
if is_swanlab_available():
|
||||||
|
integrations.append("swanlab")
|
||||||
return integrations
|
return integrations
|
||||||
|
|
||||||
|
|
||||||
@@ -2141,6 +2147,162 @@ class DVCLiveCallback(TrainerCallback):
|
|||||||
self.live.end()
|
self.live.end()
|
||||||
|
|
||||||
|
|
||||||
|
class SwanLabCallback(TrainerCallback):
|
||||||
|
"""
|
||||||
|
A [`TrainerCallback`] that logs metrics, media, and model checkpoints to [SwanLab](https://swanlab.cn/).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not is_swanlab_available():
|
||||||
|
raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
|
||||||
|
import swanlab
|
||||||
|
|
||||||
|
self._swanlab = swanlab
|
||||||
|
self._initialized = False
|
||||||
|
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
|
||||||
|
|
||||||
|
def setup(self, args, state, model, **kwargs):
|
||||||
|
"""
|
||||||
|
Set up the optional SwanLab (*swanlab*) integration.
|
||||||
|
|
||||||
|
One can subclass and override this method to customize the setup if needed. Find more information
|
||||||
|
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
|
||||||
|
|
||||||
|
You can also override the following environment variables. Find more information about environment
|
||||||
|
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
|
||||||
|
|
||||||
|
Environment:
|
||||||
|
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
|
||||||
|
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
|
||||||
|
checks if the user is already logged in. If not, the login process is initiated.
|
||||||
|
|
||||||
|
- If a string is passed to the login interface, this environment variable is ignored.
|
||||||
|
- If the user is already logged in, this environment variable takes precedence over locally stored
|
||||||
|
login information.
|
||||||
|
|
||||||
|
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
|
||||||
|
Set this to a custom string to store results in a different project. If not specified, the name of the current
|
||||||
|
running directory is used.
|
||||||
|
|
||||||
|
- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
|
||||||
|
This environment variable specifies the storage path for log files when running in local mode.
|
||||||
|
By default, logs are saved in a folder named swanlog under the working directory.
|
||||||
|
|
||||||
|
- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
|
||||||
|
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
|
||||||
|
local, cloud, and disabled. Note: Case-sensitive. Find more information
|
||||||
|
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
|
||||||
|
|
||||||
|
- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
|
||||||
|
SwanLab does not currently support the save mode functionality. This feature will be available in a future
|
||||||
|
release.
|
||||||
|
|
||||||
|
- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
|
||||||
|
Web address for the SwanLab cloud environment for private version (it's free)
|
||||||
|
|
||||||
|
- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
|
||||||
|
API address for the SwanLab cloud environment for private version (it's free)
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
|
||||||
|
combined_dict = {**args.to_dict()}
|
||||||
|
|
||||||
|
if hasattr(model, "config") and model.config is not None:
|
||||||
|
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
|
||||||
|
combined_dict = {**model_config, **combined_dict}
|
||||||
|
if hasattr(model, "peft_config") and model.peft_config is not None:
|
||||||
|
peft_config = model.peft_config
|
||||||
|
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
|
||||||
|
trial_name = state.trial_name
|
||||||
|
init_args = {}
|
||||||
|
if trial_name is not None:
|
||||||
|
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
|
||||||
|
elif args.run_name is not None:
|
||||||
|
init_args["experiment_name"] = args.run_name
|
||||||
|
init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
|
||||||
|
|
||||||
|
if self._swanlab.get_run() is None:
|
||||||
|
self._swanlab.init(
|
||||||
|
**init_args,
|
||||||
|
)
|
||||||
|
# show transformers logo!
|
||||||
|
self._swanlab.config["FRAMEWORK"] = "🤗transformers"
|
||||||
|
# add config parameters (run may have been created manually)
|
||||||
|
self._swanlab.config.update(combined_dict)
|
||||||
|
|
||||||
|
# add number of model parameters to swanlab config
|
||||||
|
try:
|
||||||
|
self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
|
||||||
|
# get peft model parameters
|
||||||
|
if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
|
||||||
|
trainable_params, all_param = model.get_nb_trainable_parameters()
|
||||||
|
self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
|
||||||
|
self._swanlab.config.update({"peft_model_all_param": all_param})
|
||||||
|
except AttributeError:
|
||||||
|
logger.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")
|
||||||
|
|
||||||
|
# log the initial model architecture to an artifact
|
||||||
|
if self._log_model is not None:
|
||||||
|
logger.warning(
|
||||||
|
"SwanLab does not currently support the save mode functionality. "
|
||||||
|
"This feature will be available in a future release."
|
||||||
|
)
|
||||||
|
badge_markdown = (
|
||||||
|
f'[<img src="https://raw.githubusercontent.com/SwanHubX/assets/main/badge1.svg"'
|
||||||
|
f' alt="Visualize in SwanLab" height="28'
|
||||||
|
f'0" height="32"/>]({self._swanlab.get_run().public.cloud.exp_url})'
|
||||||
|
)
|
||||||
|
|
||||||
|
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
|
def on_train_begin(self, args, state, control, model=None, **kwargs):
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, model, **kwargs)
|
||||||
|
|
||||||
|
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
|
||||||
|
if self._log_model is not None and self._initialized and state.is_world_process_zero:
|
||||||
|
logger.warning(
|
||||||
|
"SwanLab does not currently support the save mode functionality. "
|
||||||
|
"This feature will be available in a future release."
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
|
||||||
|
single_value_scalars = [
|
||||||
|
"train_runtime",
|
||||||
|
"train_samples_per_second",
|
||||||
|
"train_steps_per_second",
|
||||||
|
"train_loss",
|
||||||
|
"total_flos",
|
||||||
|
]
|
||||||
|
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, model)
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
for k, v in logs.items():
|
||||||
|
if k in single_value_scalars:
|
||||||
|
self._swanlab.log({f"single_value/{k}": v})
|
||||||
|
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
|
||||||
|
non_scalar_logs = rewrite_logs(non_scalar_logs)
|
||||||
|
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step})
|
||||||
|
|
||||||
|
def on_save(self, args, state, control, **kwargs):
|
||||||
|
if self._log_model is not None and self._initialized and state.is_world_process_zero:
|
||||||
|
logger.warning(
|
||||||
|
"SwanLab does not currently support the save mode functionality. "
|
||||||
|
"This feature will be available in a future release."
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_predict(self, args, state, control, metrics, **kwargs):
|
||||||
|
if not self._initialized:
|
||||||
|
self.setup(args, state, **kwargs)
|
||||||
|
if state.is_world_process_zero:
|
||||||
|
metrics = rewrite_logs(metrics)
|
||||||
|
self._swanlab.log(metrics)
|
||||||
|
|
||||||
|
|
||||||
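The routing done in `on_log` above can be sketched as a pure function. This is a simplified illustration, not the library code: `rewrite_logs_sketch` is a stand-in that assumes `rewrite_logs` maps `eval_`/`test_` prefixes to `eval/`/`test/` and everything else to `train/`, and `route_logs` is a hypothetical name.

```python
SINGLE_VALUE_SCALARS = [
    "train_runtime",
    "train_samples_per_second",
    "train_steps_per_second",
    "train_loss",
    "total_flos",
]


def rewrite_logs_sketch(logs):
    """Simplified stand-in for `rewrite_logs` (assumed behavior): prefix keys by split."""
    rewritten = {}
    for key, value in logs.items():
        if key.startswith("eval_"):
            rewritten["eval/" + key[len("eval_"):]] = value
        elif key.startswith("test_"):
            rewritten["test/" + key[len("test_"):]] = value
        else:
            rewritten["train/" + key] = value
    return rewritten


def route_logs(logs, global_step):
    """Split logs into end-of-run summary scalars and per-step metrics."""
    single = {f"single_value/{k}": v for k, v in logs.items() if k in SINGLE_VALUE_SCALARS}
    stepwise = rewrite_logs_sketch(
        {k: v for k, v in logs.items() if k not in SINGLE_VALUE_SCALARS}
    )
    stepwise["train/global_step"] = global_step
    return single, stepwise
```

In the callback each summary scalar is sent with its own `log` call, while the per-step metrics are sent together with `train/global_step`.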
INTEGRATION_TO_CALLBACK = {
|
INTEGRATION_TO_CALLBACK = {
|
||||||
"azure_ml": AzureMLCallback,
|
"azure_ml": AzureMLCallback,
|
||||||
"comet_ml": CometCallback,
|
"comet_ml": CometCallback,
|
||||||
@@ -2153,6 +2315,7 @@ INTEGRATION_TO_CALLBACK = {
|
|||||||
"dagshub": DagsHubCallback,
|
"dagshub": DagsHubCallback,
|
||||||
"flyte": FlyteCallback,
|
"flyte": FlyteCallback,
|
||||||
"dvclive": DVCLiveCallback,
|
"dvclive": DVCLiveCallback,
|
||||||
|
"swanlab": SwanLabCallback,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
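How `setup` in the `SwanLabCallback` diff above derives the arguments for `swanlab.init` can be sketched in isolation. `build_init_args` is a hypothetical helper, and the `environ` parameter is an assumption added so the logic can be exercised without touching the real process environment.

```python
import os


def build_init_args(run_name, trial_name, environ=None):
    """Mirror the experiment-name and project selection shown in `setup`."""
    environ = os.environ if environ is None else environ
    init_args = {}
    if trial_name is not None:
        # Hyperparameter-search trials get a "<run_name>-<trial_name>" experiment name.
        init_args["experiment_name"] = f"{run_name}-{trial_name}"
    elif run_name is not None:
        init_args["experiment_name"] = run_name
    # The project comes from SWANLAB_PROJECT and is None when the variable is unset.
    init_args["project"] = environ.get("SWANLAB_PROJECT", None)
    return init_args
```

For example, `build_init_args("bert-base", None, {"SWANLAB_PROJECT": "demo"})` yields an experiment named after the run inside the `demo` project.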
@@ -55,6 +55,7 @@ from .integrations import (
|
|||||||
is_optuna_available,
|
is_optuna_available,
|
||||||
is_ray_available,
|
is_ray_available,
|
||||||
is_sigopt_available,
|
is_sigopt_available,
|
||||||
|
is_swanlab_available,
|
||||||
is_tensorboard_available,
|
is_tensorboard_available,
|
||||||
is_wandb_available,
|
is_wandb_available,
|
||||||
)
|
)
|
||||||
@@ -1098,6 +1099,16 @@ def require_sigopt(test_case):
|
|||||||
return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
|
return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_swanlab(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires swanlab.
|
||||||
|
|
||||||
|
These tests are skipped when swanlab isn't installed.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_wandb(test_case):
|
def require_wandb(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires wandb.
|
Decorator marking a test that requires wandb.
|
||||||
|
|||||||
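The `require_swanlab` decorator above wraps `unittest.skipUnless`. A generic sketch of the same idea is shown below; `require_package`, the test class, and `run_suite` are hypothetical names, and the deliberately missing package name is an assumption chosen so the skip path is exercised.

```python
import importlib.util
import unittest


def require_package(name):
    """Hypothetical generic form of the `require_*` decorators."""
    available = importlib.util.find_spec(name) is not None
    return unittest.skipUnless(available, f"test requires {name}")


@require_package("surely_not_installed_pkg_xyz")
class MissingDependencyTest(unittest.TestCase):
    def test_never_runs(self):
        self.fail("should have been skipped")


def run_suite():
    """Run the test case programmatically and return the collected result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(MissingDependencyTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

A skipped test is reported separately from failures, so a test suite stays green on machines where the optional dependency is absent.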
@@ -451,8 +451,8 @@ class TrainingArguments:
|
|||||||
training step under the keyword argument `mems`.
|
training step under the keyword argument `mems`.
|
||||||
run_name (`str`, *optional*, defaults to `output_dir`):
|
run_name (`str`, *optional*, defaults to `output_dir`):
|
||||||
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),
|
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),
|
||||||
[mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will
|
[mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and [swanlab](https://swanlab.cn)
|
||||||
be the same as `output_dir`.
|
logging. If not specified, will be the same as `output_dir`.
|
||||||
disable_tqdm (`bool`, *optional*):
|
disable_tqdm (`bool`, *optional*):
|
||||||
Whether or not to disable the tqdm progress bars and table of metrics produced by
|
Whether or not to disable the tqdm progress bars and table of metrics produced by
|
||||||
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
|
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
|
||||||
@@ -642,8 +642,8 @@ class TrainingArguments:
|
|||||||
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
||||||
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
||||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
|
||||||
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
|
`"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"`
|
||||||
integrations.
|
for no integrations.
|
||||||
ddp_find_unused_parameters (`bool`, *optional*):
|
ddp_find_unused_parameters (`bool`, *optional*):
|
||||||
When using distributed training, the value of the flag `find_unused_parameters` passed to
|
When using distributed training, the value of the flag `find_unused_parameters` passed to
|
||||||
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
|
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
|
||||||
@@ -1187,7 +1187,9 @@ class TrainingArguments:
|
|||||||
|
|
||||||
run_name: Optional[str] = field(
|
run_name: Optional[str] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "An optional descriptor for the run. Notably used for wandb, mlflow and comet logging."},
|
metadata={
|
||||||
|
"help": "An optional descriptor for the run. Notably used for wandb, mlflow, comet and swanlab logging."
|
||||||
|
},
|
||||||
)
|
)
|
||||||
disable_tqdm: Optional[bool] = field(
|
disable_tqdm: Optional[bool] = field(
|
||||||
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
|
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
|
||||||
@@ -2848,8 +2850,8 @@ class TrainingArguments:
|
|||||||
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
|
||||||
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
|
||||||
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
|
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
|
||||||
`"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
|
`"neptune"`, `"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations
|
||||||
`"none"` for no integrations.
|
installed, `"none"` for no integrations.
|
||||||
first_step (`bool`, *optional*, defaults to `False`):
|
first_step (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to log and evaluate the first `global_step` or not.
|
Whether to log and evaluate the first `global_step` or not.
|
||||||
nan_inf_filter (`bool`, *optional*, defaults to `True`):
|
nan_inf_filter (`bool`, *optional*, defaults to `True`):
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
|
Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
|
||||||
automatically detect from metadata.
|
automatically detect from metadata.
|
||||||
run_name (`str`, *optional*):
|
run_name (`str`, *optional*):
|
||||||
A descriptor for the run. Notably used for wandb, mlflow and comet logging.
|
A descriptor for the run. Notably used for wandb, mlflow, comet and swanlab logging.
|
||||||
xla (`bool`, *optional*):
|
xla (`bool`, *optional*):
|
||||||
Whether to activate the XLA compilation or not.
|
Whether to activate the XLA compilation or not.
|
||||||
"""
|
"""
|
||||||
|
|||||||