[Flax] Fix incomplete batches in example scripts (#17863)
* [Flax] Fix incomplete batches in example scripts * fix dataloader batching * convert jnp batch idxs to np array * add missing `pad_shard_unpad` to final prediction generate step * only `pad_shard_unpad` at inference time * merge conflicts * remove incomplete batch step from eval * fix run_qa.py * add `pad_shard_unpad` to run_flax_ner.py * add `pad_shard_unpad` to run_flax_glue.py * add `pad_shard_unpad` to run_image_classification.py * make style * fix mlm flax eval batches * remove redundant imports
This commit is contained in:
@@ -43,7 +43,7 @@ import jax.numpy as jnp
|
||||
import optax
|
||||
import transformers
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.jax_utils import pad_shard_unpad, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository
|
||||
@@ -264,20 +264,24 @@ class TrainState(train_state.TrainState):
|
||||
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
||||
|
||||
|
||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
|
||||
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
|
||||
"""
|
||||
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
||||
Shuffle batches if `shuffle` is `True`.
|
||||
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
||||
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
||||
"""
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
|
||||
if shuffle:
|
||||
batch_idx = jax.random.permutation(rng, len(dataset))
|
||||
batch_idx = np.asarray(batch_idx)
|
||||
else:
|
||||
batch_idx = jnp.arange(len(dataset))
|
||||
batch_idx = np.arange(len(dataset))
|
||||
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
if drop_last:
|
||||
steps_per_epoch = len(dataset) // batch_size
|
||||
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
|
||||
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
|
||||
else:
|
||||
steps_per_epoch = math.ceil(len(dataset) / batch_size)
|
||||
batch_idx = np.array_split(batch_idx, steps_per_epoch)
|
||||
|
||||
for idx in batch_idx:
|
||||
batch = dataset[idx]
|
||||
@@ -621,7 +625,8 @@ def main():
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
steps_per_epoch = len(train_dataset) // train_batch_size
|
||||
total_train_steps = steps_per_epoch * num_epochs
|
||||
|
||||
@@ -764,13 +769,14 @@ def main():
|
||||
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
||||
# ======================== Evaluating ==============================
|
||||
eval_metrics = []
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
||||
eval_steps = len(eval_dataset) // eval_batch_size
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
|
||||
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
|
||||
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||
# Model forward
|
||||
batch = next(eval_loader)
|
||||
batch = shard(batch)
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
@@ -806,12 +812,14 @@ def main():
|
||||
# Eval after training
|
||||
if training_args.do_eval:
|
||||
eval_metrics = []
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
||||
eval_steps = len(eval_dataset) // eval_batch_size
|
||||
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
|
||||
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
|
||||
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
||||
# Model forward
|
||||
batch = shard(next(eval_loader))
|
||||
metrics = p_eval_step(state.params, batch)
|
||||
batch = next(eval_loader)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, batch, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
|
||||
@@ -43,6 +43,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository
|
||||
@@ -326,15 +327,20 @@ class FlaxDataCollatorForLanguageModeling:
|
||||
return inputs, labels
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
|
||||
"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
|
||||
the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
|
||||
num_samples = len(samples_idx)
|
||||
samples_to_remove = num_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
|
||||
if drop_last:
|
||||
samples_to_remove = num_samples % batch_size
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
samples_idx = samples_idx.reshape((sections_split, batch_size))
|
||||
else:
|
||||
sections_split = math.ceil(num_samples / batch_size)
|
||||
samples_idx = np.array_split(samples_idx, sections_split)
|
||||
return samples_idx
|
||||
|
||||
|
||||
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||
@@ -632,12 +638,14 @@ def main():
|
||||
config,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
|
||||
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
||||
|
||||
@@ -796,7 +804,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -804,8 +812,9 @@ def main():
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
@@ -835,7 +844,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
|
||||
eval_metrics = []
|
||||
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -843,8 +852,9 @@ def main():
|
||||
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# normalize eval metrics
|
||||
|
||||
@@ -21,6 +21,7 @@ https://huggingface.co/models?filter=t5
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -41,6 +42,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository
|
||||
@@ -326,6 +328,7 @@ class FlaxDataCollatorForT5MLM:
|
||||
decoder_start_token_id: int
|
||||
|
||||
def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
|
||||
|
||||
# convert list to dict and tensorize input
|
||||
batch = BatchEncoding(
|
||||
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
|
||||
@@ -394,6 +397,7 @@ class FlaxDataCollatorForT5MLM:
|
||||
return input_ids
|
||||
|
||||
def random_spans_noise_mask(self, length):
|
||||
|
||||
"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
|
||||
|
||||
Noise mask consisting of random spans of noise tokens.
|
||||
@@ -457,15 +461,20 @@ class FlaxDataCollatorForT5MLM:
|
||||
return is_noise[:orig_length]
|
||||
|
||||
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray:
|
||||
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
|
||||
"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
|
||||
the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
|
||||
num_samples = len(samples_idx)
|
||||
samples_to_remove = num_samples % batch_size
|
||||
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
batch_idx = np.split(samples_idx, sections_split)
|
||||
return batch_idx
|
||||
if drop_last:
|
||||
samples_to_remove = num_samples % batch_size
|
||||
if samples_to_remove != 0:
|
||||
samples_idx = samples_idx[:-samples_to_remove]
|
||||
sections_split = num_samples // batch_size
|
||||
samples_idx = samples_idx.reshape((sections_split, batch_size))
|
||||
else:
|
||||
sections_split = math.ceil(num_samples / batch_size)
|
||||
samples_idx = np.array_split(samples_idx, sections_split)
|
||||
return samples_idx
|
||||
|
||||
|
||||
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||
@@ -737,6 +746,7 @@ def main():
|
||||
config,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Data collator
|
||||
@@ -754,7 +764,8 @@ def main():
|
||||
# Store some constant
|
||||
num_epochs = int(training_args.num_train_epochs)
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
||||
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
||||
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
||||
|
||||
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
||||
|
||||
@@ -915,7 +926,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -923,8 +934,9 @@ def main():
|
||||
model_inputs = data_collator(samples)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# get eval metrics
|
||||
@@ -952,7 +964,7 @@ def main():
|
||||
num_eval_samples = len(tokenized_datasets["validation"])
|
||||
# Avoid using jax.numpy here in case of TPU training
|
||||
eval_samples_idx = np.arange(num_eval_samples)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
||||
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
||||
|
||||
eval_metrics = []
|
||||
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
||||
@@ -960,8 +972,9 @@ def main():
|
||||
model_inputs = data_collator(samples)
|
||||
|
||||
# Model forward
|
||||
model_inputs = shard(model_inputs.data)
|
||||
metrics = p_eval_step(state.params, model_inputs)
|
||||
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
||||
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
||||
)
|
||||
eval_metrics.append(metrics)
|
||||
|
||||
# get eval metrics
|
||||
|
||||
Reference in New Issue
Block a user