Adding optional trial argument to model_init (#7759)
* Adding optional trial argument to model_init Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -173,6 +173,9 @@ class Trainer:
|
|||||||
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
||||||
A function that instantiates the model to be used. If provided, each call to
|
A function that instantiates the model to be used. If provided, each call to
|
||||||
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
||||||
|
|
||||||
|
The function may take zero arguments, or a single one containing the optuna/Ray Tune trial object, to be able to choose
|
||||||
|
different architectures according to hyperparameters (such as layer count, sizes of inner layers, dropout probabilities, etc.).
|
||||||
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
||||||
The function that will be used to compute metrics at evaluation. Must take a
|
The function that will be used to compute metrics at evaluation. Must take a
|
||||||
:class:`~transformers.EvalPrediction` and return a dictionary mapping strings to metric values.
|
:class:`~transformers.EvalPrediction` and return a dictionary mapping strings to metric values.
|
||||||
@@ -212,15 +215,16 @@ class Trainer:
|
|||||||
assert (
|
assert (
|
||||||
model is not None or model_init is not None
|
model is not None or model_init is not None
|
||||||
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
||||||
|
self.model_init = model_init
|
||||||
if model is None and model_init is not None:
|
if model is None and model_init is not None:
|
||||||
model = model_init()
|
model = self.call_model_init()
|
||||||
self.model = model.to(args.device) if model is not None else None
|
self.model = model.to(args.device) if model is not None else None
|
||||||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||||
self.data_collator = data_collator if data_collator is not None else default_collator
|
self.data_collator = data_collator if data_collator is not None else default_collator
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
self.eval_dataset = eval_dataset
|
self.eval_dataset = eval_dataset
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.model_init = model_init
|
|
||||||
self.compute_metrics = compute_metrics
|
self.compute_metrics = compute_metrics
|
||||||
self.optimizer, self.lr_scheduler = optimizers
|
self.optimizer, self.lr_scheduler = optimizers
|
||||||
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
||||||
@@ -532,6 +536,17 @@ class Trainer:
|
|||||||
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
||||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
||||||
|
|
||||||
|
def call_model_init(self, trial=None):
    """
    Instantiate a model by calling ``self.model_init``.

    ``self.model_init`` may be declared with no parameter, or with a single
    parameter that receives ``trial``.

    Args:
        trial: Value forwarded to ``self.model_init`` when it declares exactly
            one parameter; ignored when it declares none. Defaults to ``None``.

    Returns:
        Whatever ``self.model_init`` returns.

    Raises:
        RuntimeError: If ``self.model_init`` declares more than one parameter.
    """
    # The declared parameter count decides whether the factory wants the trial object.
    model_init_argcount = len(inspect.signature(self.model_init).parameters)

    if model_init_argcount == 0:
        return self.model_init()
    if model_init_argcount == 1:
        return self.model_init(trial)

    # RuntimeError subclasses Exception, so callers catching Exception keep working.
    raise RuntimeError("model_init should have 0 or 1 argument.")
|
||||||
|
|
||||||
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
||||||
"""
|
"""
|
||||||
Main training entry point.
|
Main training entry point.
|
||||||
@@ -550,7 +565,9 @@ class Trainer:
|
|||||||
if self.model_init is not None:
|
if self.model_init is not None:
|
||||||
# Seed must be set before instantiating the model when using model_init.
|
# Seed must be set before instantiating the model when using model_init.
|
||||||
set_seed(self.args.seed)
|
set_seed(self.args.seed)
|
||||||
model = self.model_init()
|
|
||||||
|
model = self.call_model_init(trial)
|
||||||
|
|
||||||
self.model = model.to(self.args.device)
|
self.model = model.to(self.args.device)
|
||||||
|
|
||||||
# Reinitializes optimizer and scheduler
|
# Reinitializes optimizer and scheduler
|
||||||
|
|||||||
Reference in New Issue
Block a user