From d7f38c5d1d9e4c5bd1f766a54861d3f426ee207b Mon Sep 17 00:00:00 2001 From: Tanmay Garg <32801726+tanmay17061@users.noreply.github.com> Date: Thu, 18 Feb 2021 22:53:33 +0530 Subject: [PATCH] Introduce warmup_ratio training argument (#10229) Introduce warmup_ratio training argument in both TrainingArguments and TFTrainingArguments classes (#6673) --- src/transformers/trainer.py | 8 +++++++- src/transformers/trainer_tf.py | 8 +++++++- src/transformers/training_args.py | 15 ++++++++++++++- src/transformers/training_args_tf.py | 5 ++++- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index e1febc0b03..f5ad1219ea 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -615,10 +615,16 @@ class Trainer: self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) if self.lr_scheduler is None: + warmup_steps = ( + self.args.warmup_steps + if self.args.warmup_steps > 0 + else math.ceil(num_training_steps * self.args.warmup_ratio) + ) + self.lr_scheduler = get_scheduler( self.args.lr_scheduler_type, self.optimizer, - num_warmup_steps=self.args.warmup_steps, + num_warmup_steps=warmup_steps, num_training_steps=num_training_steps, ) diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index adc08e6c7c..509d8b77f1 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -218,10 +218,16 @@ class TFTrainer: TFTrainer's init through :obj:`optimizers`, or subclass and override this method. """ if not self.optimizer and not self.lr_scheduler: + warmup_steps = ( + self.args.warmup_steps + if self.args.warmup_steps > 0 + else math.ceil(num_training_steps * self.args.warmup_ratio) + ) + self.optimizer, self.lr_scheduler = create_optimizer( self.args.learning_rate, num_training_steps, - self.args.warmup_steps, + warmup_steps, adam_beta1=self.args.adam_beta1, adam_beta2=self.args.adam_beta2, adam_epsilon=self.args.adam_epsilon, diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 396cef24f9..4f9abba0d9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -131,8 +131,11 @@ class TrainingArguments: lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`): The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible values. + warmup_ratio (:obj:`float`, `optional`, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`. warmup_steps (:obj:`int`, `optional`, defaults to 0): - Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. + Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of + :obj:`warmup_ratio`. logging_dir (:obj:`str`, `optional`): `TensorBoard `__ log directory. Will default to `runs/**CURRENT_DATETIME_HOSTNAME**`. @@ -324,6 +327,9 @@ class TrainingArguments: default="linear", metadata={"help": "The scheduler type to use."}, ) + warmup_ratio: float = field( + default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} + ) warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."}) @@ -495,6 +501,13 @@ class TrainingArguments: elif not isinstance(self.report_to, list): self.report_to = [self.report_to] + if self.warmup_ratio < 0 or self.warmup_ratio > 1: + raise ValueError("warmup_ratio must lie in range [0,1]") + elif self.warmup_ratio > 0 and self.warmup_steps > 0: + logger.info( + "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training" + ) + def __repr__(self): # We override the default repr to remove deprecated arguments from the repr. This method should be removed once # those deprecated arguments are removed form TrainingArguments. (TODO: v5) diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py index 9d6b9492b1..8215a0122a 100644 --- a/src/transformers/training_args_tf.py +++ b/src/transformers/training_args_tf.py @@ -94,8 +94,11 @@ class TFTrainingArguments(TrainingArguments): max_steps (:obj:`int`, `optional`, defaults to -1): If set to a positive number, the total number of training steps to perform. Overrides :obj:`num_train_epochs`. + warmup_ratio (:obj:`float`, `optional`, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`. warmup_steps (:obj:`int`, `optional`, defaults to 0): - Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. + Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of + :obj:`warmup_ratio`. logging_dir (:obj:`str`, `optional`): `TensorBoard `__ log directory. Will default to `runs/**CURRENT_DATETIME_HOSTNAME**`.