[Flax] Fix incomplete batches in example scripts (#17863)
* [Flax] Fix incomplete batches in example scripts * fix dataloader batching * convert jnp batch idxs to np array * add missing `pad_shard_unpad` to final prediction generate step * only `pad_shard_unpad` at inference time * merge conflicts * remove incomplete batch step from eval * fix run_qa.py * add `pad_shard_unpad` to run_flax_ner.py * add `pad_shard_unpad` to run_flax_glue.py * add `pad_shard_unpad` to run_image_classification.py * make style * fix mlm flax eval batches * remove redundant imports
This commit is contained in:
@@ -20,6 +20,7 @@ Fine-tuning the library models for summarization.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -41,7 +42,7 @@ import optax
|
||||
import transformers
|
||||
from filelock import FileLock
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository
|
||||
@@ -335,26 +336,28 @@ class TrainState(train_state.TrainState):
|
||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||
|
||||
|
||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
|
||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||
"""
|
||||
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
||||
Shuffle batches if `shuffle` is `True`.
|
||||
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
||||
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
||||
"""
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
|
||||
if shuffle:
|
||||
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||
batch_idx = np.asarray(batch_idx)
|
||||
else:
|
||||
batch_idx = jnp.arange(len(dataset))
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
if drop_last:
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
else:
|
||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
||||
|
||||
for idx in batch_idx:
|
||||
batch = dataset[idx]
|
||||
batch = {k: jnp.array(v) for k, v in batch.items()}
|
||||
|
||||
batch = shard(batch)
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
|
||||
yield batch
|
||||
|
||||
@@ -706,7 +709,8 @@ def main():
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
total_train_steps = steps_per_epoch * num_epochs
|
||||
|
||||
@@ -850,6 +854,7 @@ def main():
|
||||
# train
|
||||
for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
||||
batch = next(train_loader)
|
||||
batch = shard(batch)
|
||||
state, train_metric = p_train_step(state, batch)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
@@ -867,21 +872,23 @@ def main():
|
||||
eval_preds = []
|
||||
eval_labels = []
|
||||
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
||||
eval_steps = len(eval_dataset) // eval_batch_size
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
|
||||
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
|
||||
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||
# Model forward
|
||||
batch = next(eval_loader)
|
||||
labels = batch["labels"]
|
||||
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# generation
|
||||
if data_args.predict_with_generate:
|
||||
generated_ids = p_generate_step(state.params, batch)
|
||||
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
|
||||
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
||||
eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
||||
eval_labels.extend(labels)
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
@@ -920,21 +927,23 @@ def main():
|
||||
pred_generations = []
|
||||
pred_labels = []
|
||||
|
||||
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
|
||||
pred_steps = len(predict_dataset) // eval_batch_size
|
||||
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
|
||||
pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
|
||||
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
|
||||
# Model forward
|
||||
batch = next(pred_loader)
|
||||
labels = batch["labels"]
|
||||
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
pred_metrics.append(metrics)
|
||||
|
||||
# generation
|
||||
if data_args.predict_with_generate:
|
||||
generated_ids = p_generate_step(state.params, batch)
|
||||
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
|
||||
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
||||
pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
||||
pred_labels.extend(labels)
|
||||
|
||||
# normalize prediction metrics
|
||||
pred_metrics = get_metrics(pred_metrics)
|
||||
|
||||
Reference in New Issue
Block a user