Fix RESOURCE_EXHAUSTED error when dealing with large datasets in Flax example scripts (#18069)
* Fix RESOURCE_EXHAUSTED error for large datasets on Flax example scripts * using np.permutation for creating batch_idx * train_samples_idx -> training_samples_idx * fix type hints
This commit is contained in:
@@ -433,7 +433,7 @@ def eval_step(params, batch):
|
||||
return compute_metrics(logits, targets, token_mask)
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||
nb_samples = len(samples_idx)
|
||||
samples_to_remove = nb_samples % batch_size
|
||||
|
||||
@@ -639,7 +639,8 @@ if __name__ == "__main__":
|
||||
|
||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||
nb_training_samples = len(tokenized_datasets["train"])
|
||||
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
training_samples_idx = np.random.permutation(np.arange(nb_training_samples))
|
||||
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
||||
|
||||
# Gather the indexes for creating the batch and do a training step
|
||||
@@ -658,7 +659,8 @@ if __name__ == "__main__":
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
nb_eval_samples = len(tokenized_datasets["validation"])
|
||||
eval_samples_idx = jnp.arange(nb_eval_samples)
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(nb_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
|
||||
eval_metrics = []
|
||||
|
||||
Reference in New Issue
Block a user