Make data shuffling in run_clm_flax.py respect global seed (#13410)
* use jax and jnp instead of numpy in data_loader
* return batches as np.ndarray
This commit is contained in:
committed by
GitHub
parent
546a91abe9
commit
2a606f9974
@@ -253,9 +253,9 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|||||||
steps_per_epoch = len(dataset) // batch_size
|
steps_per_epoch = len(dataset) // batch_size
|
||||||
|
|
||||||
if shuffle:
|
if shuffle:
|
||||||
batch_idx = np.random.permutation(len(dataset))
|
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||||
else:
|
else:
|
||||||
batch_idx = np.arange(len(dataset))
|
batch_idx = jnp.arange(len(dataset))
|
||||||
|
|
||||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||||
|
|||||||
Reference in New Issue
Block a user