Introduce warmup_ratio training argument (#10229)
Introduce warmup_ratio training argument in both TrainingArguments and TFTrainingArguments classes (#6673)
This commit is contained in:
@@ -615,10 +615,16 @@ class Trainer:
|
|||||||
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
|
||||||
if self.lr_scheduler is None:
|
if self.lr_scheduler is None:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.warmup_steps
|
||||||
|
if self.args.warmup_steps > 0
|
||||||
|
else math.ceil(num_training_steps * self.args.warmup_ratio)
|
||||||
|
)
|
||||||
|
|
||||||
self.lr_scheduler = get_scheduler(
|
self.lr_scheduler = get_scheduler(
|
||||||
self.args.lr_scheduler_type,
|
self.args.lr_scheduler_type,
|
||||||
self.optimizer,
|
self.optimizer,
|
||||||
num_warmup_steps=self.args.warmup_steps,
|
num_warmup_steps=warmup_steps,
|
||||||
num_training_steps=num_training_steps,
|
num_training_steps=num_training_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -218,10 +218,16 @@ class TFTrainer:
|
|||||||
TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
|
TFTrainer's init through :obj:`optimizers`, or subclass and override this method.
|
||||||
"""
|
"""
|
||||||
if not self.optimizer and not self.lr_scheduler:
|
if not self.optimizer and not self.lr_scheduler:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.warmup_steps
|
||||||
|
if self.args.warmup_steps > 0
|
||||||
|
else math.ceil(num_training_steps * self.args.warmup_ratio)
|
||||||
|
)
|
||||||
|
|
||||||
self.optimizer, self.lr_scheduler = create_optimizer(
|
self.optimizer, self.lr_scheduler = create_optimizer(
|
||||||
self.args.learning_rate,
|
self.args.learning_rate,
|
||||||
num_training_steps,
|
num_training_steps,
|
||||||
self.args.warmup_steps,
|
warmup_steps,
|
||||||
adam_beta1=self.args.adam_beta1,
|
adam_beta1=self.args.adam_beta1,
|
||||||
adam_beta2=self.args.adam_beta2,
|
adam_beta2=self.args.adam_beta2,
|
||||||
adam_epsilon=self.args.adam_epsilon,
|
adam_epsilon=self.args.adam_epsilon,
|
||||||
|
|||||||
@@ -131,8 +131,11 @@ class TrainingArguments:
|
|||||||
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
|
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
|
||||||
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
|
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
|
||||||
values.
|
values.
|
||||||
|
warmup_ratio (:obj:`float`, `optional`, defaults to 0.0):
|
||||||
|
Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
||||||
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
|
||||||
|
:obj:`warmup_ratio`.
|
||||||
logging_dir (:obj:`str`, `optional`):
|
logging_dir (:obj:`str`, `optional`):
|
||||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||||
@@ -324,6 +327,9 @@ class TrainingArguments:
|
|||||||
default="linear",
|
default="linear",
|
||||||
metadata={"help": "The scheduler type to use."},
|
metadata={"help": "The scheduler type to use."},
|
||||||
)
|
)
|
||||||
|
warmup_ratio: float = field(
|
||||||
|
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
|
||||||
|
)
|
||||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||||
|
|
||||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||||
@@ -495,6 +501,13 @@ class TrainingArguments:
|
|||||||
elif not isinstance(self.report_to, list):
|
elif not isinstance(self.report_to, list):
|
||||||
self.report_to = [self.report_to]
|
self.report_to = [self.report_to]
|
||||||
|
|
||||||
|
if self.warmup_ratio < 0 or self.warmup_ratio > 1:
|
||||||
|
raise ValueError("warmup_ratio must lie in range [0,1]")
|
||||||
|
elif self.warmup_ratio > 0 and self.warmup_steps > 0:
|
||||||
|
logger.info(
|
||||||
|
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
|
||||||
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||||
# those deprecated arguments are removed from TrainingArguments. (TODO: v5)
|
# those deprecated arguments are removed from TrainingArguments. (TODO: v5)
|
||||||
|
|||||||
@@ -94,8 +94,11 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
max_steps (:obj:`int`, `optional`, defaults to -1):
|
max_steps (:obj:`int`, `optional`, defaults to -1):
|
||||||
If set to a positive number, the total number of training steps to perform. Overrides
|
If set to a positive number, the total number of training steps to perform. Overrides
|
||||||
:obj:`num_train_epochs`.
|
:obj:`num_train_epochs`.
|
||||||
|
warmup_ratio (:obj:`float`, `optional`, defaults to 0.0):
|
||||||
|
Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
||||||
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
||||||
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
|
||||||
|
:obj:`warmup_ratio`.
|
||||||
logging_dir (:obj:`str`, `optional`):
|
logging_dir (:obj:`str`, `optional`):
|
||||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||||
|
|||||||
Reference in New Issue
Block a user