[examples/flax] add adafactor optimizer (#12544)
* add adafactor * Update examples/flax/language-modeling/run_mlm_flax.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -2,4 +2,4 @@ datasets >= 1.1.3
|
|||||||
jax>=0.2.8
|
jax>=0.2.8
|
||||||
jaxlib>=0.1.59
|
jaxlib>=0.1.59
|
||||||
flax>=0.3.4
|
flax>=0.3.4
|
||||||
optax>=0.0.8
|
optax>=0.0.9
|
||||||
|
|||||||
@@ -489,17 +489,24 @@ def main():
|
|||||||
return traverse_util.unflatten_dict(flat_mask)
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
adamw = optax.adamw(
|
if training_args.adafactor:
|
||||||
learning_rate=linear_decay_lr_schedule_fn,
|
# We use the default parameters here to initialize adafactor,
|
||||||
b1=training_args.adam_beta1,
|
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
||||||
b2=training_args.adam_beta2,
|
optimizer = optax.adafactor(
|
||||||
eps=training_args.adam_epsilon,
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
weight_decay=training_args.weight_decay,
|
)
|
||||||
mask=decay_mask_fn,
|
else:
|
||||||
)
|
optimizer = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=training_args.adam_beta1,
|
||||||
|
b2=training_args.adam_beta2,
|
||||||
|
eps=training_args.adam_epsilon,
|
||||||
|
weight_decay=training_args.weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup train state
|
# Setup train state
|
||||||
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
|
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
|
||||||
|
|
||||||
def loss_fn(logits, labels):
|
def loss_fn(logits, labels):
|
||||||
shift_logits = logits[..., :-1, :]
|
shift_logits = logits[..., :-1, :]
|
||||||
|
|||||||
@@ -513,17 +513,24 @@ if __name__ == "__main__":
|
|||||||
return traverse_util.unflatten_dict(flat_mask)
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
adamw = optax.adamw(
|
if training_args.adafactor:
|
||||||
learning_rate=linear_decay_lr_schedule_fn,
|
# We use the default parameters here to initialize adafactor,
|
||||||
b1=training_args.adam_beta1,
|
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
||||||
b2=training_args.adam_beta2,
|
optimizer = optax.adafactor(
|
||||||
eps=1e-8,
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
weight_decay=training_args.weight_decay,
|
)
|
||||||
mask=decay_mask_fn,
|
else:
|
||||||
)
|
optimizer = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=training_args.adam_beta1,
|
||||||
|
b2=training_args.adam_beta2,
|
||||||
|
eps=training_args.adam_epsilon,
|
||||||
|
weight_decay=training_args.weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup train state
|
# Setup train state
|
||||||
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
|
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
||||||
|
|
||||||
# Define gradient update step fn
|
# Define gradient update step fn
|
||||||
def train_step(state, batch, dropout_rng):
|
def train_step(state, batch, dropout_rng):
|
||||||
|
|||||||
@@ -635,16 +635,23 @@ if __name__ == "__main__":
|
|||||||
return traverse_util.unflatten_dict(flat_mask)
|
return traverse_util.unflatten_dict(flat_mask)
|
||||||
|
|
||||||
# create adam optimizer
|
# create adam optimizer
|
||||||
adamw = optax.adamw(
|
if training_args.adafactor:
|
||||||
learning_rate=linear_decay_lr_schedule_fn,
|
# We use the default parameters here to initialize adafactor,
|
||||||
b1=training_args.adam_beta1,
|
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
||||||
b2=training_args.adam_beta2,
|
optimizer = optax.adafactor(
|
||||||
weight_decay=training_args.weight_decay,
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
mask=decay_mask_fn,
|
)
|
||||||
)
|
else:
|
||||||
|
optimizer = optax.adamw(
|
||||||
|
learning_rate=linear_decay_lr_schedule_fn,
|
||||||
|
b1=training_args.adam_beta1,
|
||||||
|
b2=training_args.adam_beta2,
|
||||||
|
weight_decay=training_args.weight_decay,
|
||||||
|
mask=decay_mask_fn,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup train state
|
# Setup train state
|
||||||
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
|
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
||||||
|
|
||||||
# Define gradient update step fn
|
# Define gradient update step fn
|
||||||
def train_step(state, batch, dropout_rng):
|
def train_step(state, batch, dropout_rng):
|
||||||
|
|||||||
Reference in New Issue
Block a user