Fix RESOURCE_EXHAUSTED error when dealing with large datasets in Flax example scripts (#18069)
* Fix RESOURCE_EXHAUSTED error for large datasets on Flax example scripts * using np.permutation for creating batch_idx * train_samples_idx -> training_samples_idx * fix type hints
This commit is contained in:
@@ -326,7 +326,7 @@ class FlaxDataCollatorForLanguageModeling:
|
|||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
|
|
||||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||||
num_samples = len(samples_idx)
|
num_samples = len(samples_idx)
|
||||||
samples_to_remove = num_samples % batch_size
|
samples_to_remove = num_samples % batch_size
|
||||||
|
|
||||||
@@ -755,7 +755,8 @@ def main():
|
|||||||
|
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
num_train_samples = len(tokenized_datasets["train"])
|
num_train_samples = len(tokenized_datasets["train"])
|
||||||
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
|
||||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||||
|
|
||||||
# Gather the indexes for creating the batch and do a training step
|
# Gather the indexes for creating the batch and do a training step
|
||||||
@@ -787,7 +788,8 @@ def main():
|
|||||||
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
@@ -825,7 +827,8 @@ def main():
|
|||||||
# Eval after training
|
# Eval after training
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
|
|||||||
@@ -459,7 +459,7 @@ class FlaxDataCollatorForT5MLM:
|
|||||||
return is_noise[:orig_length]
|
return is_noise[:orig_length]
|
||||||
|
|
||||||
|
|
||||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||||
num_samples = len(samples_idx)
|
num_samples = len(samples_idx)
|
||||||
samples_to_remove = num_samples % batch_size
|
samples_to_remove = num_samples % batch_size
|
||||||
|
|
||||||
@@ -871,6 +871,7 @@ def main():
|
|||||||
|
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
num_train_samples = len(tokenized_datasets["train"])
|
num_train_samples = len(tokenized_datasets["train"])
|
||||||
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
|
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
|
||||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||||
|
|
||||||
@@ -908,7 +909,8 @@ def main():
|
|||||||
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
@@ -944,7 +946,8 @@ def main():
|
|||||||
# Eval after training
|
# Eval after training
|
||||||
if training_args.do_eval:
|
if training_args.do_eval:
|
||||||
num_eval_samples = len(tokenized_datasets["validation"])
|
num_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ class FlaxDataCollatorForLanguageModeling:
|
|||||||
return inputs, labels
|
return inputs, labels
|
||||||
|
|
||||||
|
|
||||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||||
num_samples = len(samples_idx)
|
num_samples = len(samples_idx)
|
||||||
samples_to_remove = num_samples % batch_size
|
samples_to_remove = num_samples % batch_size
|
||||||
|
|
||||||
@@ -592,7 +592,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
if step % training_args.eval_steps == 0 and step > 0:
|
if step % training_args.eval_steps == 0 and step > 0:
|
||||||
eval_samples_idx = jnp.arange(data_args.num_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(data_args.num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
|
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
|||||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||||
|
|
||||||
|
|
||||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||||
num_samples = len(samples_idx)
|
num_samples = len(samples_idx)
|
||||||
samples_to_remove = num_samples % batch_size
|
samples_to_remove = num_samples % batch_size
|
||||||
|
|
||||||
@@ -541,7 +541,8 @@ def main():
|
|||||||
|
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
num_train_samples = len(vectorized_datasets["train"])
|
num_train_samples = len(vectorized_datasets["train"])
|
||||||
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
|
||||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||||
|
|
||||||
# Gather the indexes for creating the batch and do a training step
|
# Gather the indexes for creating the batch and do a training step
|
||||||
@@ -574,7 +575,8 @@ def main():
|
|||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
num_eval_samples = len(vectorized_datasets["validation"])
|
num_eval_samples = len(vectorized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(num_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
|
|||||||
@@ -433,7 +433,7 @@ def eval_step(params, batch):
|
|||||||
return compute_metrics(logits, targets, token_mask)
|
return compute_metrics(logits, targets, token_mask)
|
||||||
|
|
||||||
|
|
||||||
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||||
nb_samples = len(samples_idx)
|
nb_samples = len(samples_idx)
|
||||||
samples_to_remove = nb_samples % batch_size
|
samples_to_remove = nb_samples % batch_size
|
||||||
|
|
||||||
@@ -639,7 +639,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Generate an epoch by shuffling sampling indices from the train dataset
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
||||||
nb_training_samples = len(tokenized_datasets["train"])
|
nb_training_samples = len(tokenized_datasets["train"])
|
||||||
training_samples_idx = jax.random.permutation(training_rng, jnp.arange(nb_training_samples))
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
training_samples_idx = np.random.permutation(np.arange(nb_training_samples))
|
||||||
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
training_batch_idx = generate_batch_splits(training_samples_idx, batch_size)
|
||||||
|
|
||||||
# Gather the indexes for creating the batch and do a training step
|
# Gather the indexes for creating the batch and do a training step
|
||||||
@@ -658,7 +659,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
nb_eval_samples = len(tokenized_datasets["validation"])
|
nb_eval_samples = len(tokenized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(nb_eval_samples)
|
# Avoid using jax.numpy here in case of TPU training
|
||||||
|
eval_samples_idx = np.arange(nb_eval_samples)
|
||||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||||
|
|
||||||
eval_metrics = []
|
eval_metrics = []
|
||||||
|
|||||||
Reference in New Issue
Block a user