Improvements to Flax finetuning script (#11727)
* Add Cloud details to README * Flax script and readme updates * Some simplifications of Flax script
This commit is contained in:
@@ -59,20 +59,20 @@ On the task other than MRPC and WNLI we train for 3 these epochs because this is
|
|||||||
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
|
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
|
||||||
are undertrained and we could get better results when training longer.
|
are undertrained and we could get better results when training longer.
|
||||||
|
|
||||||
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1zKL_xn32HwbxkFMxB3ftca-soTHAuBFgIhYhOhCnZ4E/edit?usp=sharing).
|
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 2, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1wtcjX_fJLjYs6kXkoiej2qGjrl9ByfNhPulPAz71Ky4/edit?usp=sharing).
|
||||||
|
|
||||||
|
|
||||||
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|
||||||
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
|
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
|
||||||
| CoLA | Matthew's corr | 59.57 | 58.04 | 1.81 | [tfhub.dev](https://tensorboard.dev/experiment/f4OvQpWtRq6CvddpxGBd0A/) |
|
| CoLA | Matthew's corr | 59.29 | 56.25 | 2.18 | [tfhub.dev](https://tensorboard.dev/experiment/tNBiYyvsRv69ZlXRI7x0pQ/) |
|
||||||
| SST-2 | Accuracy | 92.43 | 91.79 | 0.59 | [tfhub.dev](https://tensorboard.dev/experiment/BYFwa49MRTaLIn93DgAEtA/) |
|
| SST-2 | Accuracy | 91.97 | 91.79 | 0.42 | [tfhub.dev](https://tensorboard.dev/experiment/wQto9nBwQHOINUxjKAAblQ/) |
|
||||||
| MRPC | F1/Accuracy | 89.50/84.8 | 88.70/84.02 | 0.56/0.48 | [tfhub.dev](https://tensorboard.dev/experiment/9ZWH5xwXRS6zEEUE4RaBhQ/) |
|
| MRPC | F1/Accuracy | 90.39/86.03 | 89.70/85.20 | 0.68/0.91 | [tfhub.dev](https://tensorboard.dev/experiment/Q40mkOtDSYymFRfo4jKsgQ/) |
|
||||||
| STS-B | Pearson/Spearman corr. | 90.00/88.71 | 89.09/88.61 | 0.51/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/mUlI5B9QQ0WGEJip7p3Tng/) |
|
| STS-B | Pearson/Spearman corr. | 89.19/88.91 | 89.40/89.09 | 0.18/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/a2bfeAy6SveV0X0FjwxMXQ/) |
|
||||||
| QQP | Accuracy/F1 | 90.88/87.64 | 90.75/87.53 | 0.11/0.13 | [tfhub.dev](https://tensorboard.dev/experiment/pO6h75L3SvSXSWRcgljXKA/) |
|
| QQP | Accuracy/F1 | 91.02/87.90 | 90.96/87.75 | 0.08/0.14 | [tfhub.dev](https://tensorboard.dev/experiment/kL2vGgoQQeyTVGetehbCpg/) |
|
||||||
| MNLI | Matched acc. | 84.06 | 83.88 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/LKwaOH18RMuo7nJkESrpKg/) |
|
| MNLI | Matched acc. | 83.82 | 83.65 | 0.28 | [tfhub.dev](https://tensorboard.dev/experiment/nck6178dTpmTOPm7862urA/) |
|
||||||
| QNLI | Accuracy | 91.01 | 90.86 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/qesXxNcaQhmKxPmbw1sOoA/) |
|
| QNLI | Accuracy | 90.81 | 90.88 | 0.18 | [tfhub.dev](https://tensorboard.dev/experiment/44slZTLKQtqGhWs1Rhedcg/) |
|
||||||
| RTE | Accuracy | 66.80 | 65.27 | 1.07 | [tfhub.dev](https://tensorboard.dev/experiment/Z84xC0r6RjyzT4SLqiAbzQ/) |
|
| RTE | Accuracy | 69.31 | 66.79 | 1.88 | [tfhub.dev](https://tensorboard.dev/experiment/g0yvpEXKSAytDMvP8TP8Og/) |
|
||||||
| WNLI | Accuracy | 39.44 | 32.96 | 5.85 | [tfhub.dev](https://tensorboard.dev/experiment/gV73w9v0RIKrqVw32PZbAQ/) |
|
| WNLI | Accuracy | 56.34 | 36.62 | 12.48 | [tfhub.dev](https://tensorboard.dev/experiment/7DfXdlDnTWWKBEx4pXForA/) |
|
||||||
|
|
||||||
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
|
||||||
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ def parse_args():
|
|||||||
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
||||||
)
|
)
|
||||||
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
||||||
parser.add_argument("--seed", type=int, default=2, help="A seed for reproducible training.")
|
parser.add_argument("--seed", type=int, default=5, help="A seed for reproducible training.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Sanity checks
|
# Sanity checks
|
||||||
@@ -148,6 +148,7 @@ def create_train_state(
|
|||||||
learning_rate_fn: Callable[[int], float],
|
learning_rate_fn: Callable[[int], float],
|
||||||
is_regression: bool,
|
is_regression: bool,
|
||||||
num_labels: int,
|
num_labels: int,
|
||||||
|
weight_decay: float,
|
||||||
) -> train_state.TrainState:
|
) -> train_state.TrainState:
|
||||||
"""Create initial training state."""
|
"""Create initial training state."""
|
||||||
|
|
||||||
@@ -166,8 +167,8 @@ def create_train_state(
|
|||||||
loss_fn: Callable = struct.field(pytree_node=False)
|
loss_fn: Callable = struct.field(pytree_node=False)
|
||||||
|
|
||||||
# Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
|
# Creates a multi-optimizer consisting of two "Adam with weight decay" optimizers.
|
||||||
def adamw(weight_decay):
|
def adamw(decay):
|
||||||
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay)
|
return optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=decay)
|
||||||
|
|
||||||
def traverse(fn):
|
def traverse(fn):
|
||||||
def mask(data):
|
def mask(data):
|
||||||
@@ -183,7 +184,7 @@ def create_train_state(
|
|||||||
|
|
||||||
tx = optax.chain(
|
tx = optax.chain(
|
||||||
optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
|
optax.masked(adamw(0.0), mask=traverse(lambda path, _: decay_path(path))),
|
||||||
optax.masked(adamw(0.01), mask=traverse(lambda path, _: not decay_path(path))),
|
optax.masked(adamw(weight_decay), mask=traverse(lambda path, _: not decay_path(path))),
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_regression:
|
if is_regression:
|
||||||
@@ -414,7 +415,9 @@ def main():
|
|||||||
len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
|
len(train_dataset), train_batch_size, args.num_train_epochs, args.num_warmup_steps, args.learning_rate
|
||||||
)
|
)
|
||||||
|
|
||||||
state = create_train_state(model, learning_rate_fn, is_regression, num_labels=num_labels)
|
state = create_train_state(
|
||||||
|
model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=args.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
# define step functions
|
# define step functions
|
||||||
def train_step(
|
def train_step(
|
||||||
@@ -426,10 +429,10 @@ def main():
|
|||||||
def loss_fn(params):
|
def loss_fn(params):
|
||||||
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
||||||
loss = state.loss_fn(logits, targets)
|
loss = state.loss_fn(logits, targets)
|
||||||
return loss, logits
|
return loss
|
||||||
|
|
||||||
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
grad_fn = jax.value_and_grad(loss_fn)
|
||||||
(loss, logits), grad = grad_fn(state.params)
|
loss, grad = grad_fn(state.params)
|
||||||
grad = jax.lax.pmean(grad, "batch")
|
grad = jax.lax.pmean(grad, "batch")
|
||||||
new_state = state.apply_gradients(grads=grad)
|
new_state = state.apply_gradients(grads=grad)
|
||||||
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
|
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
|
||||||
@@ -460,10 +463,11 @@ def main():
|
|||||||
|
|
||||||
train_start = time.time()
|
train_start = time.time()
|
||||||
train_metrics = []
|
train_metrics = []
|
||||||
rng, input_rng, dropout_rng = jax.random.split(rng, 3)
|
rng, input_rng = jax.random.split(rng)
|
||||||
|
|
||||||
# train
|
# train
|
||||||
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
|
for batch in glue_train_data_collator(input_rng, train_dataset, train_batch_size):
|
||||||
|
rng, dropout_rng = jax.random.split(rng)
|
||||||
dropout_rngs = shard_prng_key(dropout_rng)
|
dropout_rngs = shard_prng_key(dropout_rng)
|
||||||
state, metrics = p_train_step(state, batch, dropout_rngs)
|
state, metrics = p_train_step(state, batch, dropout_rngs)
|
||||||
train_metrics.append(metrics)
|
train_metrics.append(metrics)
|
||||||
@@ -471,7 +475,6 @@ def main():
|
|||||||
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
logger.info(f" Done! Training metrics: {unreplicate(metrics)}")
|
||||||
|
|
||||||
logger.info(" Evaluating...")
|
logger.info(" Evaluating...")
|
||||||
rng, input_rng = jax.random.split(rng)
|
|
||||||
|
|
||||||
# evaluate
|
# evaluate
|
||||||
for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
|
for batch in glue_eval_data_collator(eval_dataset, eval_batch_size):
|
||||||
@@ -484,20 +487,14 @@ def main():
|
|||||||
|
|
||||||
# make sure leftover batch is evaluated on one device
|
# make sure leftover batch is evaluated on one device
|
||||||
if num_leftover_samples > 0 and jax.process_index() == 0:
|
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||||
# put weights on single device
|
|
||||||
state = unreplicate(state)
|
|
||||||
|
|
||||||
# take leftover samples
|
# take leftover samples
|
||||||
batch = eval_dataset[-num_leftover_samples:]
|
batch = eval_dataset[-num_leftover_samples:]
|
||||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
batch = {k: jnp.array(v) for k, v in batch.items()}
|
||||||
|
|
||||||
labels = batch.pop("labels")
|
labels = batch.pop("labels")
|
||||||
predictions = eval_step(state, batch)
|
predictions = eval_step(unreplicate(state), batch)
|
||||||
metric.add_batch(predictions=predictions, references=labels)
|
metric.add_batch(predictions=predictions, references=labels)
|
||||||
|
|
||||||
# make sure weights are replicated on each device
|
|
||||||
state = replicate(state)
|
|
||||||
|
|
||||||
eval_metric = metric.compute()
|
eval_metric = metric.compute()
|
||||||
logger.info(f" Done! Eval metrics: {eval_metric}")
|
logger.info(f" Done! Eval metrics: {eval_metric}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user