[Flax] Fix incomplete batches in example scripts (#17863)
* [Flax] Fix incomplete batches in example scripts * fix dataloader batching * convert jnp batch idxs to np array * add missing `pad_shard_unpad` to final prediction generate step * only `pad_shard_unpad` at inference time * merge conflicts * remove incomplete batch step from eval * fix run_qa.py * add `pad_shard_unpad` to run_flax_ner.py * add `pad_shard_unpad` to run_flax_glue.py * add `pad_shard_unpad` to run_image_classification.py * make style * fix mlm flax eval batches * remove redundant imports
This commit is contained in:
@@ -43,6 +43,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository
|
||||
@@ -326,15 +327,20 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
|
||||
"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
|
||||
the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
|
||||
num_samples = len(samples_idx)
|
||||
samples_to_remove = num_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
|
||||
if drop_last:
|
||||
samples_to_remove = num_samples % batch_size
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
samples_idx = samples_idx.reshape((sections_split, batch_size))
|
||||
else:
|
||||
sections_split = math.ceil(num_samples / batch_size)
|
||||
samples_idx = np.array_split(samples_idx, sections_split)
|
||||
return samples_idx
|
||||
|
||||
|
||||
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||
@@ -632,12 +638,14 @@ def main():
|
||||
config,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
|
||||
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
||||
|
||||
@@ -796,7 +804,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -804,8 +812,9 @@ def main():
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
@@ -835,7 +844,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
|
||||
eval_metrics = []
|
||||
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -843,8 +852,9 @@ def main():
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
|
||||
Reference in New Issue
Block a user