[Flax] Fix incomplete batches in example scripts (#17863)

* [Flax] Fix incomplete batches in example scripts

* fix dataloader batching

* convert jnp batch idxs to np array

* add missing `pad_shard_unpad` to final prediction generate step

* only `pad_shard_unpad` at inference time

* merge conflicts

* remove incomplete batch step from eval

* fix run_qa.py

* add `pad_shard_unpad` to run_flax_ner.py

* add `pad_shard_unpad` to run_flax_glue.py

* add `pad_shard_unpad` to run_image_classification.py

* make style

* fix mlm flax eval batches

* remove redundant imports
This commit is contained in:
Sanchit Gandhi
2022-07-27 15:50:47 +01:00
committed by GitHub
parent 9caf68a638
commit 7490a97cac
8 changed files with 180 additions and 197 deletions

View File

@@ -40,7 +40,7 @@ import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
@@ -368,7 +368,8 @@ def main():
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
@@ -398,7 +399,7 @@ def main():
shuffle=False,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=True,
drop_last=False,
collate_fn=collate_fn,
)
@@ -532,8 +533,9 @@ def main():
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
for batch in eval_loader:
# Model forward
batch = shard(batch)
metrics = p_eval_step(state.params, batch)
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
eval_step_progress_bar.update(1)