From 208df208bf2b258677ece94e86c7e6e08ad2cd41 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 6 Jul 2021 19:41:51 +0100 Subject: [PATCH] [Flax] Adapt examples to be able to use eval_steps and save_steps (#12543) * fix_torch_device_generate_test * remove @ * up * up * correct * upload Co-authored-by: Patrick von Platen --- examples/flax/language-modeling/README.md | 6 ++ .../flax/language-modeling/run_clm_flax.py | 69 ++++++++++--------- .../flax/language-modeling/run_mlm_flax.py | 68 +++++++++--------- .../flax/language-modeling/run_t5_mlm_flax.py | 59 ++++++++-------- 4 files changed, 107 insertions(+), 95 deletions(-) diff --git a/examples/flax/language-modeling/README.md b/examples/flax/language-modeling/README.md index 48b5613283..ad0b30cf41 100644 --- a/examples/flax/language-modeling/README.md +++ b/examples/flax/language-modeling/README.md @@ -141,6 +141,8 @@ Next we can run the example script to pretrain the model: --adam_beta1="0.9" \ --adam_beta2="0.98" \ --logging_steps="500" \ + --save_steps="2500" \ + --eval_steps="2500" \ --push_to_hub ``` @@ -234,6 +236,8 @@ Next we can run the example script to pretrain the model: --overwrite_output_dir \ --num_train_epochs="20" \ --logging_steps="500" \ + --save_steps="2500" \ + --eval_steps="2500" \ --push_to_hub ``` @@ -370,6 +374,8 @@ Next we can run the example script to pretrain the model: --overwrite_output_dir \ --num_train_epochs="10" \ --logging_steps="500" \ + --save_steps="2500" \ + --eval_steps="2500" \ --push_to_hub ``` diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index b63612bd93..1d84af80c4 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -587,45 +587,46 @@ def main(): train_metrics = [] - # ======================== Evaluating ============================== - eval_metrics = [] - eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) - eval_steps = len(eval_dataset) // eval_batch_size - for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): - # Model forward - batch = next(eval_loader) - metrics = p_eval_step(state.params, batch) - eval_metrics.append(metrics) + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + # ======================== Evaluating ============================== + eval_metrics = [] + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) + eval_steps = len(eval_dataset) // eval_batch_size + for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): + # Model forward + batch = next(eval_loader) + metrics = p_eval_step(state.params, batch) + eval_metrics.append(metrics) - # normalize eval metrics - eval_metrics = get_metrics(eval_metrics) + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + try: + eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) + except OverflowError: + eval_metrics["perplexity"] = float("inf") - try: - eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) - except OverflowError: - eval_metrics["perplexity"] = float("inf") + # Print metrics and update progress bar + desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})" + epochs.write(desc) + epochs.desc = desc - # Print metrics and update progress bar - desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})" - epochs.write(desc) - epochs.desc = desc + # Save metrics + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(train_dataset) // train_batch_size) + write_eval_metric(summary_writer, eval_metrics, cur_step) - # Save metrics - if has_tensorboard and jax.process_index() == 0: - cur_step = epoch * (len(train_dataset) // train_batch_size) - write_eval_metric(summary_writer, eval_metrics, cur_step) - - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(unreplicate(state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of epoch {epoch+1}", - ) + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(unreplicate(state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of step {cur_step}", + ) if __name__ == "__main__": diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 3bb74d1a06..4da1908dbe 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -621,43 +621,43 @@ if __name__ == "__main__": train_metrics = [] - # ======================== Evaluating ============================== - num_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + # ======================== Evaluating ============================== + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) - eval_metrics = [] - for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): - samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] - model_inputs = data_collator(samples, pad_to_multiple_of=16) + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples, pad_to_multiple_of=16) - # Model forward - model_inputs = shard(model_inputs.data) - metrics = p_eval_step(state.params, model_inputs) - eval_metrics.append(metrics) + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) - # normalize eval metrics - eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.sum, eval_metrics) - eval_normalizer = eval_metrics.pop("normalizer") - eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) + # normalize eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.sum, eval_metrics) + eval_normalizer = eval_metrics.pop("normalizer") + eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics) - # Update progress bar - epochs.desc = ( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" - ) + # Update progress bar + epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" - # Save metrics - if has_tensorboard and jax.process_index() == 0: - cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) - write_eval_metric(summary_writer, eval_metrics, cur_step) + # Save metrics + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) + write_eval_metric(summary_writer, eval_metrics, cur_step) - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained( - training_args.output_dir, - params=params, - push_to_hub=training_args.push_to_hub, - commit_message=f"Saving weights and logs of epoch {epoch+1}", - ) + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of step {cur_step}", + ) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index dc87f0093a..56c27c8752 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -737,36 +737,41 @@ if __name__ == "__main__": train_metrics = [] - # ======================== Evaluating ============================== - num_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + if cur_step % training_args.eval_steps == 0 and cur_step > 0: + # ======================== Evaluating ============================== + num_eval_samples = len(tokenized_datasets["validation"]) + eval_samples_idx = jnp.arange(num_eval_samples) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) - eval_metrics = [] - for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): - samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] - model_inputs = data_collator(samples) + eval_metrics = [] + for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): + samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx] + model_inputs = data_collator(samples) - # Model forward - model_inputs = shard(model_inputs.data) - metrics = p_eval_step(state.params, model_inputs) - eval_metrics.append(metrics) + # Model forward + model_inputs = shard(model_inputs.data) + metrics = p_eval_step(state.params, model_inputs) + eval_metrics.append(metrics) - # get eval metrics - eval_metrics = get_metrics(eval_metrics) - eval_metrics = jax.tree_map(jnp.mean, eval_metrics) + # get eval metrics + eval_metrics = get_metrics(eval_metrics) + eval_metrics = jax.tree_map(jnp.mean, eval_metrics) - # Update progress bar - epochs.write( - f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})" - ) + # Update progress bar + epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})") - # Save metrics - if has_tensorboard and jax.process_index() == 0: - cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) - write_eval_metric(summary_writer, eval_metrics, cur_step) + # Save metrics + if has_tensorboard and jax.process_index() == 0: + cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size) + write_eval_metric(summary_writer, eval_metrics, cur_step) - # save checkpoint after each epoch and push checkpoint to the hub - if jax.process_index() == 0: - params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) - model.save_pretrained(training_args.output_dir, params=params, push_to_hub=training_args.push_to_hub) + if cur_step % training_args.save_steps == 0 and cur_step > 0: + # save checkpoint after each epoch and push checkpoint to the hub + if jax.process_index() == 0: + params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) + model.save_pretrained( + training_args.output_dir, + params=params, + push_to_hub=training_args.push_to_hub, + commit_message=f"Saving weights and logs of step {cur_step}", + )