From 91fb62d01c77d6ff7c01287bb6457f83732d9d61 Mon Sep 17 00:00:00 2001 From: Yeb Havinga Date: Tue, 8 Mar 2022 12:18:38 +0100 Subject: [PATCH] Speedup training by using numpy instead of jnp for batch shuffling (#15963) Speedup training by using numpy instead of jnp for batch shuffling Co-authored-by: Yeb Havinga --- examples/flax/language-modeling/run_t5_mlm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index e0ea0fa3fb..83ef2dbc30 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -810,7 +810,7 @@ def main(): # Generate an epoch by shuffling sampling indices from the train dataset num_train_samples = len(tokenized_datasets["train"]) - train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples)) + train_samples_idx = np.random.permutation(np.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step