Fix RESOURCE_EXHAUSTED error when dealing with large datasets in Flax example scripts (#18069)

* Fix RESOURCE_EXHAUSTED error for large datasets on Flax example scripts

* using np.permutation for creating batch_idx

* train_samples_idx -> training_samples_idx

* fix type hints
This commit is contained in:
Duong A. Nguyen
2022-07-11 20:59:08 +07:00
committed by GitHub
parent ac98a88fbc
commit 1e8140caad
5 changed files with 26 additions and 15 deletions

View File

@@ -326,7 +326,7 @@ class FlaxDataCollatorForLanguageModeling:
return inputs, labels
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
num_samples = len(samples_idx)
samples_to_remove = num_samples % batch_size
@@ -755,7 +755,8 @@ def main():
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(tokenized_datasets["train"])
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
# Avoid using jax.numpy here in case of TPU training
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
@@ -787,7 +788,8 @@ def main():
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
@@ -825,7 +827,8 @@ def main():
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
eval_samples_idx = jnp.arange(num_eval_samples)
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []