TensorBoard/Wandb/optuna/raytune integration improvements. (#7935)
Improved TensorBoard and Wandb integration, as well as optuna and ray/tune support, with minor modifications to trainer core code.
This commit is contained in:
@@ -19,7 +19,7 @@ Callbacks to use with the Trainer class and customize the training loop.
|
||||
import dataclasses
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
@@ -66,6 +66,9 @@ class TrainerState:
|
||||
is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not this process is the global main process (when training in a distributed fashion on
|
||||
several machines, this is only going to be :obj:`True` for one process).
|
||||
is_hyper_param_search (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search.
|
||||
This will impact the way data will be logged in TensorBoard.
|
||||
"""
|
||||
|
||||
epoch: Optional[float] = None
|
||||
@@ -78,6 +81,9 @@ class TrainerState:
|
||||
best_model_checkpoint: Optional[str] = None
|
||||
is_local_process_zero: bool = True
|
||||
is_world_process_zero: bool = True
|
||||
is_hyper_param_search: bool = False
|
||||
trial_name: str = None
|
||||
trial_params: Dict[str, Union[str, float, int, bool]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.log_history is None:
|
||||
|
||||
Reference in New Issue
Block a user