Trainer + wandb quality of life logging tweaks (#6241)
* Added `name` argument for wandb logging; also log the model config together with the trainer arguments
* Update src/transformers/training_args.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Added TF support, post-review changes
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -383,7 +383,10 @@ class Trainer:
|
|||||||
logger.info(
|
logger.info(
|
||||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||||
)
|
)
|
||||||
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=self.args.to_sanitized_dict())
|
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||||
|
wandb.init(
|
||||||
|
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
||||||
|
)
|
||||||
# keep track of model topology and gradients, unsupported on TPU
|
# keep track of model topology and gradients, unsupported on TPU
|
||||||
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
||||||
wandb.watch(
|
wandb.watch(
|
||||||
|
|||||||
@@ -215,7 +215,8 @@ class TFTrainer:
|
|||||||
return self._setup_wandb()
|
return self._setup_wandb()
|
||||||
|
|
||||||
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
|
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
|
||||||
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
|
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||||
|
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name)
|
||||||
|
|
||||||
def prediction_loop(
|
def prediction_loop(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -109,6 +109,8 @@ class TrainingArguments:
|
|||||||
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
|
||||||
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
|
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
|
||||||
at the next training step under the keyword argument ``mems``.
|
at the next training step under the keyword argument ``mems``.
|
||||||
|
run_name (:obj:`str`, `optional`):
|
||||||
|
A descriptor for the run. Notably used for wandb logging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
output_dir: str = field(
|
output_dir: str = field(
|
||||||
@@ -222,6 +224,10 @@ class TrainingArguments:
|
|||||||
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
|
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
run_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_batch_size(self) -> int:
|
def train_batch_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
at the next training step under the keyword argument ``mems``.
|
at the next training step under the keyword argument ``mems``.
|
||||||
tpu_name (:obj:`str`, `optional`):
|
tpu_name (:obj:`str`, `optional`):
|
||||||
The name of the TPU the process is running on.
|
The name of the TPU the process is running on.
|
||||||
|
run_name (:obj:`str`, `optional`):
|
||||||
|
A descriptor for the run. Notably used for wandb logging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tpu_name: str = field(
|
tpu_name: str = field(
|
||||||
|
|||||||
Reference in New Issue
Block a user