From 2a606f9974feb0f7578e6a638c7e5b548523ecb4 Mon Sep 17 00:00:00 2001 From: Benjamin Minixhofer Date: Tue, 14 Dec 2021 11:04:43 +0100 Subject: [PATCH] Make data shuffling in `run_clm_flax.py` respect global seed (#13410) * use jax and jnp instead of numpy in data_loader * return batches as np.ndarray --- examples/flax/language-modeling/run_clm_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 46cb16a921..7746400d09 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -253,9 +253,9 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf steps_per_epoch = len(dataset) // batch_size if shuffle: - batch_idx = np.random.permutation(len(dataset)) + batch_idx = jax.random.permutation(rng, len(dataset)) else: - batch_idx = np.arange(len(dataset)) + batch_idx = jnp.arange(len(dataset)) batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))