Adds train_batch_size, eval_batch_size, and n_gpu to to_sanitized_dict output for logging. (#5331)

* Adds train_batch_size, eval_batch_size, and n_gpu to to_sanitized_dict() output

* Update wandb config logging to use to_sanitized_dict

* removed n_gpu from sanitized dict

* fix quality check errors
This commit is contained in:
Jay Mody
2020-08-03 09:00:39 -04:00
committed by GitHub
parent 9996f697e3
commit cedc547e7e
2 changed files with 4 additions and 1 deletions

View File

@@ -310,7 +310,10 @@ class TrainingArguments:
Sanitized serialization to use with TensorBoards hparams
"""
d = dataclasses.asdict(self)
d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
valid_types = [bool, int, float, str]
if is_torch_available():
valid_types.append(torch.Tensor)
return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}