[Flax] Fix incomplete batches in example scripts (#17863)
* [Flax] Fix incomplete batches in example scripts * fix dataloader batching * convert jnp batch idxs to np array * add missing `pad_shard_unpad` to final prediction generate step * only `pad_shard_unpad` at inference time * merge conflicts * remove incomplete batch step from eval * fix run_qa.py * add `pad_shard_unpad` to run_flax_ner.py * add `pad_shard_unpad` to run_flax_glue.py * add `pad_shard_unpad` to run_image_classification.py * make style * fix mlm flax eval batches * remove redundant imports
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
""" Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
@@ -36,7 +37,7 @@ import jax.numpy as jnp
|
||||
import optax
|
||||
import transformers
|
||||
from flax import struct, traverse_util
|
||||
from flax.jax_utils import replicate, unreplicate
|
||||
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository
|
||||
@@ -351,11 +352,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
|
||||
|
||||
|
||||
def eval_data_collator(dataset: Dataset, batch_size: int):
|
||||
"""Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices."""
|
||||
for i in range(len(dataset) // batch_size):
|
||||
batch = dataset[i * batch_size : (i + 1) * batch_size]
|
||||
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
||||
|
||||
for idx in batch_idx:
|
||||
batch = dataset[idx]
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
batch = shard(batch)
|
||||
|
||||
yield batch
|
||||
|
||||
@@ -600,6 +605,7 @@ def main():
|
||||
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
||||
|
||||
train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
|
||||
|
||||
learning_rate_fn = create_learning_rate_fn(
|
||||
@@ -728,34 +734,16 @@ def main():
|
||||
# evaluate
|
||||
for batch in tqdm(
|
||||
eval_data_collator(eval_dataset, eval_batch_size),
|
||||
total=len(eval_dataset) // eval_batch_size,
|
||||
total=math.ceil(len(eval_dataset) / eval_batch_size),
|
||||
desc="Evaluating ...",
|
||||
position=2,
|
||||
):
|
||||
labels = batch.pop("labels")
|
||||
predictions = p_eval_step(state, batch)
|
||||
predictions = np.array([pred for pred in chain(*predictions)])
|
||||
labels = np.array([label for label in chain(*labels)])
|
||||
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
|
||||
preds, refs = get_labels(predictions, labels)
|
||||
metric.add_batch(
|
||||
predictions=preds,
|
||||
references=refs,
|
||||
predictions = pad_shard_unpad(p_eval_step)(
|
||||
state, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
|
||||
# evaluate also on leftover examples (not divisible by batch_size)
|
||||
num_leftover_samples = len(eval_dataset) % eval_batch_size
|
||||
|
||||
# make sure leftover batch is evaluated on one device
|
||||
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||
# take leftover samples
|
||||
batch = eval_dataset[-num_leftover_samples:]
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
|
||||
labels = batch.pop("labels")
|
||||
predictions = eval_step(unreplicate(state), batch)
|
||||
labels = np.array(labels)
|
||||
labels[np.array(batch["attention_mask"]) == 0] = -100
|
||||
predictions = np.array(predictions)
|
||||
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
|
||||
preds, refs = get_labels(predictions, labels)
|
||||
metric.add_batch(
|
||||
predictions=preds,
|
||||
@@ -791,28 +779,12 @@ def main():
|
||||
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
|
||||
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
|
||||
labels = batch.pop("labels")
|
||||
predictions = p_eval_step(state, batch)
|
||||
predictions = np.array([pred for pred in chain(*predictions)])
|
||||
labels = np.array([label for label in chain(*labels)])
|
||||
predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
|
||||
predictions = np.array(predictions)
|
||||
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
|
||||
preds, refs = get_labels(predictions, labels)
|
||||
metric.add_batch(predictions=preds, references=refs)
|
||||
|
||||
# evaluate also on leftover examples (not divisible by batch_size)
|
||||
num_leftover_samples = len(eval_dataset) % eval_batch_size
|
||||
|
||||
# make sure leftover batch is evaluated on one device
|
||||
if num_leftover_samples > 0 and jax.process_index() == 0:
|
||||
# take leftover samples
|
||||
batch = eval_dataset[-num_leftover_samples:]
|
||||
batch = {k: np.array(v) for k, v in batch.items()}
|
||||
|
||||
labels = np.array(batch.pop("labels"))
|
||||
predictions = eval_step(unreplicate(state), batch)
|
||||
labels[np.array(batch["attention_mask"]) == 0] = -100
|
||||
preds, refs = get_labels(predictions, labels)
|
||||
metric.add_batch(predictions=preds, references=refs)
|
||||
|
||||
eval_metrics = compute_metrics()
|
||||
|
||||
if jax.process_index() == 0:
|
||||
|
||||
Reference in New Issue
Block a user