Add hyperparameter search to Trainer (#6576)

* Add optuna hyperparameter search to Trainer

* @julien-c suggestions

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Make compute_objective an arg function

* Formatting

* Rework to make it easier to add ray

* Formatting

* Initial support for Ray

* Formatting

* Polish and finalize

* Add trial id to checkpoint with Ray

* Smaller default

* Use GPU in ray if available

* Formatting

* Fix test

* Update install instruction

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

* Address review comments

* Formatting post-merge

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Sylvain Gugger
2020-08-24 11:48:45 -04:00
committed by GitHub
parent dd522da004
commit 3a7fdd3f52
5 changed files with 325 additions and 21 deletions

View File

@@ -114,6 +114,9 @@ class TrainingArguments:
at the next training step under the keyword argument ``mems``.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
disable_tqdm (:obj:`bool`, `optional`):
Whether or not to disable the tqdm progress bars. Will default to :obj:`True` if the logging level is set
to warn or lower (default), :obj:`False` otherwise.
remove_unused_columns (:obj:`bool`, `optional`, defaults to :obj:`True`):
If using `nlp.Dataset` datasets, whether or not to automatically remove the columns unused by the model
forward method.
@@ -238,6 +241,13 @@ class TrainingArguments:
run_name: Optional[str] = field(
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
)
disable_tqdm: Optional[bool] = field(
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
)
def __post_init__(self):
if self.disable_tqdm is None:
self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
remove_unused_columns: Optional[bool] = field(
default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}