Speedup training by using numpy instead of jnp for batch shuffling (#15963)
Speedup training by using numpy instead of jnp for batch shuffling Co-authored-by: Yeb Havinga <y.t.havinga@mgrid.net>
This commit is contained in:
@@ -810,7 +810,7 @@ def main():
|
|||||||
|
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
num_train_samples = len(tokenized_datasets["train"])
|
num_train_samples = len(tokenized_datasets["train"])
|
||||||
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
|
||||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||||
|
|
||||||
# Gather the indexes for creating the batch and do a training step
|
# Gather the indexes for creating the batch and do a training step
|
||||||
|
|||||||
Reference in New Issue
Block a user