From 1e8140caad248ba7ea8797f0575f7931bd4b12a6 Mon Sep 17 00:00:00 2001 From: "Duong A. Nguyen" <38061659+duongna21@users.noreply.github.com> Date: Mon, 11 Jul 2022 20:59:08 +0700 Subject: [PATCH] Fix RESOURCE_EXHAUSTED error when dealing with large datasets in Flax example scripts (#18069) * Fix RESOURCE_EXHAUSTED error for large datasets on Flax example scripts * using np.permutation for creating batch_idx * train_samples_idx -> training_samples_idx * fix type hints --- examples/flax/language-modeling/run_mlm_flax.py | 11 +++++++---- examples/flax/language-modeling/run_t5_mlm_flax.py | 9 ++++++--- .../dataset-streaming/run_mlm_flax_stream.py | 5 +++-- .../wav2vec2/run_wav2vec2_pretrain_flax.py | 8 +++++--- .../research_projects/performer/run_mlm_performer.py | 8 +++++--- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 3538ba2683..831cceef2b 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -326,7 +326,7 @@ class FlaxDataCollatorForLanguageModeling: return inputs, labels -def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: num_samples = len(samples_idx) samples_to_remove = num_samples % batch_size @@ -755,7 +755,8 @@ def main(): # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(tokenized_datasets["train"]) - train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) + # Avoid using jax.numpy here in case of TPU training + train_samples_idx = np.random.permutation(np.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step @@ -787,7 +788,8 @@ def main(): if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== num_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] @@ -825,7 +827,8 @@ def main(): # Eval after training if training_args.do_eval: num_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index 48a58b60c0..892db76924 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -459,7 +459,7 @@ class FlaxDataCollatorForT5MLM: return is_noise[:orig_length] -def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: num_samples = len(samples_idx) samples_to_remove = num_samples % batch_size @@ -871,6 +871,7 @@ def main(): # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(tokenized_datasets["train"]) + # Avoid using jax.numpy here in case of TPU training train_samples_idx = np.random.permutation(np.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) @@ -908,7 +909,8 @@ def main(): if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== num_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] @@ -944,7 +946,8 @@ def main(): # Eval after training if training_args.do_eval: num_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] diff --git a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py index f0f3e873d8..fadcec09cb 100755 --- a/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py +++ b/examples/research_projects/jax-projects/dataset-streaming/run_mlm_flax_stream.py @@ -264,7 +264,7 @@ class FlaxDataCollatorForLanguageModeling: return inputs, labels -def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: num_samples = len(samples_idx) samples_to_remove = num_samples % batch_size @@ -592,7 +592,8 @@ if __name__ == "__main__": # ======================== Evaluating ============================== if step % training_args.eval_steps == 0 and step > 0: - eval_samples_idx = jnp.arange(data_args.num_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(data_args.num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)): diff --git a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py index b0600d978b..457c58d44f 100755 --- a/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py +++ b/examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py @@ -237,7 +237,7 @@ def write_eval_metric(summary_writer, eval_metrics, step): summary_writer.scalar(f"eval_{metric_name}", value, step) -def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: num_samples = len(samples_idx) samples_to_remove = num_samples % batch_size @@ -541,7 +541,8 @@ def main(): # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(vectorized_datasets["train"]) - train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) + # Avoid using jax.numpy here in case of TPU training + train_samples_idx = np.random.permutation(np.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step @@ -574,7 +575,8 @@ def main(): # ======================== Evaluating ============================== num_eval_samples = len(vectorized_datasets["validation"]) - eval_samples_idx = jnp.arange(num_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(num_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] diff --git a/examples/research_projects/performer/run_mlm_performer.py b/examples/research_projects/performer/run_mlm_performer.py index be20342d3a..8e8fe91765 100644 --- a/examples/research_projects/performer/run_mlm_performer.py +++ b/examples/research_projects/performer/run_mlm_performer.py @@ -433,7 +433,7 @@ def eval_step(params, batch): return compute_metrics(logits, targets, token_mask) -def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: nb_samples = len(samples_idx) samples_to_remove = nb_samples % batch_size @@ -639,7 +639,8 @@ if __name__ == "__main__": # Generate an epoch by shuffling sampling indices from the train dataset nb_training_samples = len(tokenized_datasets["train"]) - training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples)) + # Avoid using jax.numpy here in case of TPU training + training_samples_idx = np.random.permutation(np.arange(nb_training_samples)) training_batch_idx = generate_batch_splits(training_samples_idx, batch_size) # Gather the indexes for creating the batch and do a training step @@ -658,7 +659,8 @@ if __name__ == "__main__": # ======================== Evaluating ============================== nb_eval_samples = len(tokenized_datasets["validation"]) - eval_samples_idx = jnp.arange(nb_eval_samples) + # Avoid using jax.numpy here in case of TPU training + eval_samples_idx = np.arange(nb_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = []