Introduce warmup_ratio training argument (#10229)
Introduce warmup_ratio training argument in both TrainingArguments and TFTrainingArguments classes (#6673)
This commit is contained in:
@@ -131,8 +131,11 @@ class TrainingArguments:
|
||||
lr_scheduler_type (:obj:`str` or :class:`~transformers.SchedulerType`, `optional`, defaults to :obj:`"linear"`):
|
||||
The scheduler type to use. See the documentation of :class:`~transformers.SchedulerType` for all possible
|
||||
values.
|
||||
warmup_ratio (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Ratio of total training steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
||||
warmup_steps (:obj:`int`, `optional`, defaults to 0):
|
||||
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`.
|
||||
Number of steps used for a linear warmup from 0 to :obj:`learning_rate`. Overrides any effect of
|
||||
:obj:`warmup_ratio`.
|
||||
logging_dir (:obj:`str`, `optional`):
|
||||
`TensorBoard <https://www.tensorflow.org/tensorboard>`__ log directory. Will default to
|
||||
`runs/**CURRENT_DATETIME_HOSTNAME**`.
|
||||
@@ -324,6 +327,9 @@ class TrainingArguments:
|
||||
default="linear",
|
||||
metadata={"help": "The scheduler type to use."},
|
||||
)
|
||||
warmup_ratio: float = field(
|
||||
default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
|
||||
)
|
||||
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
|
||||
|
||||
logging_dir: Optional[str] = field(default_factory=default_logdir, metadata={"help": "Tensorboard log dir."})
|
||||
@@ -495,6 +501,13 @@ class TrainingArguments:
|
||||
elif not isinstance(self.report_to, list):
|
||||
self.report_to = [self.report_to]
|
||||
|
||||
if self.warmup_ratio < 0 or self.warmup_ratio > 1:
|
||||
raise ValueError("warmup_ratio must lie in range [0,1]")
|
||||
elif self.warmup_ratio > 0 and self.warmup_steps > 0:
|
||||
logger.info(
|
||||
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
|
||||
|
||||
Reference in New Issue
Block a user