diff --git a/examples/flax/language-modeling/requirements.txt b/examples/flax/language-modeling/requirements.txt index 6ab626a17f..4ef5d1afd0 100644 --- a/examples/flax/language-modeling/requirements.txt +++ b/examples/flax/language-modeling/requirements.txt @@ -2,4 +2,4 @@ datasets >= 1.1.3 jax>=0.2.8 jaxlib>=0.1.59 flax>=0.3.4 -optax>=0.0.8 +optax>=0.0.9 diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 1d84af80c4..d20de1b1d6 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -489,17 +489,24 @@ def main(): return traverse_util.unflatten_dict(flat_mask) # create adam optimizer - adamw = optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, - b1=training_args.adam_beta1, - b2=training_args.adam_beta2, - eps=training_args.adam_epsilon, - weight_decay=training_args.weight_decay, - mask=decay_mask_fn, - ) + if training_args.adafactor: + # We use the default parameters here to initialize adafactor, + # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 + optimizer = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + ) + else: + optimizer = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) # Setup train state - state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) + state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng) def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 4da1908dbe..1ae9f05230 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -513,17 +513,24 @@ if __name__ == "__main__": return traverse_util.unflatten_dict(flat_mask) # create adam optimizer - adamw = optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, - b1=training_args.adam_beta1, - b2=training_args.adam_beta2, - eps=1e-8, - weight_decay=training_args.weight_decay, - mask=decay_mask_fn, - ) + if training_args.adafactor: + # We use the default parameters here to initialize adafactor, + # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 + optimizer = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + ) + else: + optimizer = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + eps=training_args.adam_epsilon, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) # Setup train state - state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) + state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer) # Define gradient update step fn def train_step(state, batch, dropout_rng): diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 56c27c8752..09d6753a48 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -635,16 +635,23 @@ if __name__ == "__main__": return traverse_util.unflatten_dict(flat_mask) # create adam optimizer - adamw = optax.adamw( - learning_rate=linear_decay_lr_schedule_fn, - b1=training_args.adam_beta1, - b2=training_args.adam_beta2, - weight_decay=training_args.weight_decay, - mask=decay_mask_fn, - ) + if training_args.adafactor: + # We use the default parameters here to initialize adafactor, + # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 + optimizer = optax.adafactor( + learning_rate=linear_decay_lr_schedule_fn, + ) + else: + optimizer = optax.adamw( + learning_rate=linear_decay_lr_schedule_fn, + b1=training_args.adam_beta1, + b2=training_args.adam_beta2, + weight_decay=training_args.weight_decay, + mask=decay_mask_fn, + ) # Setup train state - state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) + state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer) # Define gradient update step fn def train_step(state, batch, dropout_rng):