[Flax] Fix incomplete batches in example scripts (#17863)
* [Flax] Fix incomplete batches in example scripts * fix dataloader batching * convert jnp batch idxs to np array * add missing `pad_shard_unpad` to final prediction generate step * only `pad_shard_unpad` at inference time * merge conflicts * remove incomplete batch step from eval * fix run_qa.py * add `pad_shard_unpad` to run_flax_ner.py * add `pad_shard_unpad` to run_flax_glue.py * add `pad_shard_unpad` to run_image_classification.py * make style * fix mlm flax eval batches * remove redundant imports
This commit is contained in:
@@ -40,7 +40,7 @@ import jax.numpy as jnp
|
||||
import optax
|
||||
import transformers
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository
|
||||
@@ -368,7 +368,8 @@ def main():
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
total_train_steps = steps_per_epoch * num_epochs
|
||||
|
||||
@@ -398,7 +399,7 @@ def main():
|
||||
shuffle=False,
|
||||
num_workers=data_args.preprocessing_num_workers,
|
||||
persistent_workers=True,
|
||||
drop_last=True,
|
||||
drop_last=False,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
@@ -532,8 +533,9 @@ def main():
|
||||
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
|
||||
for batch in eval_loader:
|
||||
# Model forward
|
||||
batch = shard(batch)
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
eval_step_progress_bar.update(1)
|
||||
|
||||
Reference in New Issue
Block a user