[examples/flax] add adafactor optimizer (#12544)

* add adafactor

* Update examples/flax/language-modeling/run_mlm_flax.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2021-07-07 11:50:30 +05:30
committed by GitHub
parent 208df208bf
commit 2d42915abe
4 changed files with 48 additions and 27 deletions

View File

@@ -2,4 +2,4 @@ datasets >= 1.1.3
jax>=0.2.8 jax>=0.2.8
jaxlib>=0.1.59 jaxlib>=0.1.59
flax>=0.3.4 flax>=0.3.4
optax>=0.0.8 optax>=0.0.9

View File

@@ -489,7 +489,14 @@ def main():
return traverse_util.unflatten_dict(flat_mask) return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer # create adam optimizer
adamw = optax.adamw( if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn, learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1, b1=training_args.adam_beta1,
b2=training_args.adam_beta2, b2=training_args.adam_beta2,
@@ -499,7 +506,7 @@ def main():
) )
# Setup train state # Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
def loss_fn(logits, labels): def loss_fn(logits, labels):
shift_logits = logits[..., :-1, :] shift_logits = logits[..., :-1, :]

View File

@@ -513,17 +513,24 @@ if __name__ == "__main__":
return traverse_util.unflatten_dict(flat_mask) return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer # create adam optimizer
adamw = optax.adamw( if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn, learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1, b1=training_args.adam_beta1,
b2=training_args.adam_beta2, b2=training_args.adam_beta2,
eps=1e-8, eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay, weight_decay=training_args.weight_decay,
mask=decay_mask_fn, mask=decay_mask_fn,
) )
# Setup train state # Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
# Define gradient update step fn # Define gradient update step fn
def train_step(state, batch, dropout_rng): def train_step(state, batch, dropout_rng):

View File

@@ -635,7 +635,14 @@ if __name__ == "__main__":
return traverse_util.unflatten_dict(flat_mask) return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer # create adam optimizer
adamw = optax.adamw( if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn, learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1, b1=training_args.adam_beta1,
b2=training_args.adam_beta2, b2=training_args.adam_beta2,
@@ -644,7 +651,7 @@ if __name__ == "__main__":
) )
# Setup train state # Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw) state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
# Define gradient update step fn # Define gradient update step fn
def train_step(state, batch, dropout_rng): def train_step(state, batch, dropout_rng):