From b21993b3625354a2e6255be09b0a9acec068ec11 Mon Sep 17 00:00:00 2001 From: Gong Linyuan Date: Mon, 27 Jul 2020 17:31:37 +0800 Subject: [PATCH] Allow to set Adam beta1, beta2 in TrainingArgs (#5592) * Add Adam beta1, beta2 to trainer * Make style consistent --- src/transformers/optimization_tf.py | 14 +++++++++++--- src/transformers/trainer.py | 7 ++++++- src/transformers/trainer_tf.py | 2 ++ src/transformers/training_args.py | 2 ++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/transformers/optimization_tf.py b/src/transformers/optimization_tf.py index 21c3557ae8..7043bcdf1f 100644 --- a/src/transformers/optimization_tf.py +++ b/src/transformers/optimization_tf.py @@ -84,6 +84,8 @@ def create_optimizer( num_train_steps: int, num_warmup_steps: int, min_lr_ratio: float = 0.0, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, adam_epsilon: float = 1e-8, weight_decay_rate: float = 0.0, include_in_weight_decay: Optional[List[str]] = None, @@ -100,6 +102,10 @@ def create_optimizer( The number of warmup steps. min_lr_ratio (:obj:`float`, `optional`, defaults to 0): The final learning rate at the end of the linear decay will be :obj:`init_lr * min_lr_ratio`. + adam_beta1 (:obj:`float`, `optional`, defaults to 0.9): + The beta1 to use in Adam. + adam_beta2 (:obj:`float`, `optional`, defaults to 0.999): + The beta2 to use in Adam. adam_epsilon (:obj:`float`, `optional`, defaults to 1e-8): The epsilon to use in Adam. 
weight_decay_rate (:obj:`float`, `optional`, defaults to 0): @@ -122,14 +128,16 @@ def create_optimizer( optimizer = AdamWeightDecay( learning_rate=lr_schedule, weight_decay_rate=weight_decay_rate, - beta_1=0.9, - beta_2=0.999, + beta_1=adam_beta1, + beta_2=adam_beta2, epsilon=adam_epsilon, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"], include_in_weight_decay=include_in_weight_decay, ) else: - optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, epsilon=adam_epsilon) + optimizer = tf.keras.optimizers.Adam( + learning_rate=lr_schedule, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon + ) # We return the optimizer and the LR scheduler in order to better track the # evolution of the LR independently of the optimizer. return optimizer, lr_schedule diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7a6db77887..06d467a354 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -343,7 +343,12 @@ class Trainer: "weight_decay": 0.0, }, ] - optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) + optimizer = AdamW( + optimizer_grouped_parameters, + lr=self.args.learning_rate, + betas=(self.args.adam_beta1, self.args.adam_beta2), + eps=self.args.adam_epsilon, + ) scheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps ) diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index 8e56251b27..bd3bf1e925 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -171,6 +171,8 @@ class TFTrainer: self.args.learning_rate, num_training_steps, self.args.warmup_steps, + adam_beta1=self.args.adam_beta1, + adam_beta2=self.args.adam_beta2, adam_epsilon=self.args.adam_epsilon, weight_decay_rate=self.args.weight_decay, ) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 90fc8f266d..e6506d9763 
100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -160,6 +160,8 @@ class TrainingArguments: learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."}) weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."}) + adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for Adam optimizer"}) + adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for Adam optimizer"}) adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for Adam optimizer."}) max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})