Trainer + wandb quality of life logging tweaks (#6241)

* added `name` argument for wandb logging, also logging model config with trainer arguments

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added tf, post-review changes

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Teven
2020-08-05 15:05:52 +02:00
committed by GitHub
parent 33966811bd
commit bd0eab351a
4 changed files with 14 additions and 2 deletions

View File

@@ -383,7 +383,10 @@ class Trainer:
logger.info( logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
) )
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=self.args.to_sanitized_dict()) combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
# keep track of model topology and gradients, unsupported on TPU # keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch( wandb.watch(

View File

@@ -215,7 +215,8 @@ class TFTrainer:
return self._setup_wandb() return self._setup_wandb()
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"') logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args)) combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name)
def prediction_loop( def prediction_loop(
self, self,

View File

@@ -109,6 +109,8 @@ class TrainingArguments:
make use of the past hidden states for their predictions. If this argument is set to a positive int, the make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``. at the next training step under the keyword argument ``mems``.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
""" """
output_dir: str = field( output_dir: str = field(
@@ -222,6 +224,10 @@ class TrainingArguments:
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."}, metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
) )
run_name: Optional[str] = field(
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
)
@property @property
def train_batch_size(self) -> int: def train_batch_size(self) -> int:
""" """

View File

@@ -95,6 +95,8 @@ class TFTrainingArguments(TrainingArguments):
at the next training step under the keyword argument ``mems``. at the next training step under the keyword argument ``mems``.
tpu_name (:obj:`str`, `optional`): tpu_name (:obj:`str`, `optional`):
The name of the TPU the process is running on. The name of the TPU the process is running on.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
""" """
tpu_name: str = field( tpu_name: str = field(