From 2d6e2ad4fac12efd416994de13fbac298128f229 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Lagunas?=
Date: Tue, 13 Oct 2020 17:07:02 +0200
Subject: [PATCH] Adding optional trial argument to model_init (#7759)

* Adding optional trial argument to model_init

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/trainer.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3656ee1bc3..96451daa5d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -173,6 +173,9 @@ class Trainer:
         model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
             A function that instantiates the model to be used. If provided, each call to
             :meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
+
+            The function may take zero arguments, or a single one containing the optuna/Ray Tune trial object, to be able to choose
+            different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
         compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
             The function that will be used to compute metrics at evaluation. Must take a
             :class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
@@ -212,15 +215,16 @@ class Trainer:
         assert (
             model is not None or model_init is not None
         ), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
+        self.model_init = model_init
         if model is None and model_init is not None:
-            model = model_init()
+            model = self.call_model_init()
         self.model = model.to(args.device) if model is not None else None
         default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
         self.data_collator = data_collator if data_collator is not None else default_collator
         self.train_dataset = train_dataset
         self.eval_dataset = eval_dataset
         self.tokenizer = tokenizer
-        self.model_init = model_init
+
         self.compute_metrics = compute_metrics
         self.optimizer, self.lr_scheduler = optimizers
         if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
@@ -532,6 +536,17 @@ class Trainer:
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
 
+    def call_model_init(self, trial=None):
+        """
+        Call ``self.model_init``, passing ``trial`` if it takes one argument; raise RuntimeError for any other arity.
+        """
+        model_init_argcount = len(inspect.signature(self.model_init).parameters)
+        if model_init_argcount == 0:
+            return self.model_init()
+        if model_init_argcount == 1:
+            return self.model_init(trial)
+        raise RuntimeError("model_init should have 0 or 1 argument.")
+
     def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
         """
         Main training entry point.
@@ -550,7 +565,9 @@ class Trainer:
         if self.model_init is not None:
             # Seed must be set before instantiating the model when using model_init.
             set_seed(self.args.seed)
-            model = self.model_init()
+
+            model = self.call_model_init(trial)
+
             self.model = model.to(self.args.device)
 
             # Reinitializes optimizer and scheduler