[Flax] Fix incomplete batches in example scripts (#17863)

* [Flax] Fix incomplete batches in example scripts

* fix dataloader batching

* convert jnp batch idxs to np array

* add missing `pad_shard_unpad` to final prediction generate step

* only `pad_shard_unpad` at inference time

* merge conflicts

* remove incomplete batch step from eval

* fix run_qa.py

* add `pad_shard_unpad` to run_flax_ner.py

* add `pad_shard_unpad` to run_flax_glue.py

* add `pad_shard_unpad` to run_image_classification.py

* make style

* fix mlm flax eval batches

* remove redundant imports
This commit is contained in:
Sanchit Gandhi
2022-07-27 15:50:47 +01:00
committed by GitHub
parent 9caf68a638
commit 7490a97cac
8 changed files with 180 additions and 197 deletions

View File

@@ -16,12 +16,12 @@
""" Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import json
import logging
import math
import os
import random
import sys
import time
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
@@ -35,7 +35,7 @@ import jax.numpy as jnp
import optax
import transformers
from flax import struct, traverse_util
from flax.jax_utils import replicate, unreplicate
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
@@ -300,11 +300,15 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
def glue_eval_data_collator(dataset: Dataset, batch_size: int):
"""Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
for i in range(len(dataset) // batch_size):
batch = dataset[i * batch_size : (i + 1) * batch_size]
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
batch_idx = np.arange(len(dataset))
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
batch = shard(batch)
yield batch
@@ -521,8 +525,9 @@ def main():
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
learning_rate_fn = create_learning_rate_fn(
len(train_dataset),
@@ -621,26 +626,15 @@ def main():
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(
eval_loader,
total=len(eval_dataset) // eval_batch_size,
total=math.ceil(len(eval_dataset) / eval_batch_size),
desc="Evaluating ...",
position=2,
):
labels = batch.pop("labels")
predictions = p_eval_step(state, batch)
metric.add_batch(predictions=chain(*predictions), references=chain(*labels))
# evaluate also on leftover examples (not divisible by batch_size)
num_leftover_samples = len(eval_dataset) % eval_batch_size
# make sure leftover batch is evaluated on one device
if num_leftover_samples > 0 and jax.process_index() == 0:
# take leftover samples
batch = eval_dataset[-num_leftover_samples:]
batch = {k: np.array(v) for k, v in batch.items()}
labels = batch.pop("labels")
predictions = eval_step(unreplicate(state), batch)
metric.add_batch(predictions=predictions, references=labels)
predictions = pad_shard_unpad(p_eval_step)(
state, batch, min_device_batch=per_device_eval_batch_size
)
metric.add_batch(predictions=np.array(predictions), references=labels)
eval_metric = metric.compute()