Adding optional trial argument to model_init (#7759)
* Adding optional trial argument to model_init Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -173,6 +173,9 @@ class Trainer:
|
||||
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
||||
A function that instantiates the model to be used. If provided, each call to
|
||||
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
||||
|
||||
The function may have zero arguments, or a single one containing the optuna/Ray Tune trial object, to be able to choose
|
||||
different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
|
||||
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
||||
The function that will be used to compute metrics at evaluation. Must take a
|
||||
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
||||
@@ -212,15 +215,16 @@ class Trainer:
|
||||
assert (
|
||||
model is not None or model_init is not None
|
||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||
self.model_init = model_init
|
||||
if model is None and model_init is not None:
|
||||
model = model_init()
|
||||
model = self.call_model_init()
|
||||
self.model = model.to(args.device) if model is not None else None
|
||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.tokenizer = tokenizer
|
||||
self.model_init = model_init
|
||||
|
||||
self.compute_metrics = compute_metrics
|
||||
self.optimizer, self.lr_scheduler = optimizers
|
||||
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
||||
@@ -532,6 +536,17 @@ class Trainer:
|
||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
def call_model_init(self, trial=None):
    """
    Instantiate a model by calling :obj:`self.model_init`.

    The number of parameters declared by :obj:`self.model_init` decides how it is called: with no argument
    when it declares none, or with :obj:`trial` as its only argument when it declares exactly one.

    Args:
        trial (`optional`, defaults to :obj:`None`):
            Object forwarded to a one-parameter :obj:`model_init`. It is ignored when :obj:`model_init`
            declares no parameter.

    Returns:
        Whatever :obj:`self.model_init` returns.

    Raises:
        RuntimeError: If :obj:`self.model_init` declares more than one parameter.
    """
    # ``parameters`` lists every declared parameter, including defaulted and variadic ones,
    # so the dispatch below is based on the declared signature only.
    model_init_argcount = len(inspect.signature(self.model_init).parameters)

    if model_init_argcount == 0:
        return self.model_init()
    if model_init_argcount == 1:
        return self.model_init(trial)

    # RuntimeError is a subclass of Exception, so handlers written for the previous
    # bare ``Exception`` keep working while new callers can catch something narrower.
    raise RuntimeError("model_init should have 0 or 1 argument.")
|
||||
|
||||
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
||||
"""
|
||||
Main training entry point.
|
||||
@@ -550,7 +565,9 @@ class Trainer:
|
||||
if self.model_init is not None:
|
||||
# Seed must be set before instantiating the model when using model_init.
|
||||
set_seed(self.args.seed)
|
||||
model = self.model_init()
|
||||
|
||||
model = self.call_model_init(trial)
|
||||
|
||||
self.model = model.to(self.args.device)
|
||||
|
||||
# Reinitializes optimizer and scheduler
|
||||
|
||||
Reference in New Issue
Block a user