Update examples with image processors (#21155)

* Update examples to use image processors

* Small fixes

* Resolve conflicts
This commit is contained in:
amyeroberts
2023-01-19 15:14:58 +00:00
committed by GitHub
parent fc8a93507c
commit 4bc18e7a83
12 changed files with 124 additions and 137 deletions

View File

@@ -47,7 +47,7 @@ from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from transformers import (
AutoFeatureExtractor,
AutoImageProcessor,
AutoTokenizer,
FlaxVisionEncoderDecoderModel,
HfArgumentParser,
@@ -106,12 +106,12 @@ class TrainingArguments:
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
_block_size_doc = """
The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
The default value `0` will preprocess (tokenization + image processing) the whole dataset before training and
cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
good option if your disk space is large enough to store the whole processed dataset.
If a positive value is given, the captions in the dataset will be tokenized before training and the results are
cached. During training, it iterates the dataset in chunks of size `block_size`. On each block, images are
transformed by the feature extractor with the results being kept in memory (no cache), and batches of size
transformed by the image processor with the results being kept in memory (no cache), and batches of size
`batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
dataset is large.
"""
@@ -477,7 +477,7 @@ def main():
dtype=getattr(jnp, model_args.dtype),
use_auth_token=True if model_args.use_auth_token else None,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
image_processor = AutoImageProcessor.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_auth_token=True if model_args.use_auth_token else None,
@@ -546,7 +546,7 @@ def main():
for image_file in examples[image_column]:
try:
image = Image.open(image_file)
feature_extractor(images=image, return_tensors="np")
image_processor(images=image, return_tensors="np")
bools.append(True)
except Exception:
bools.append(False)
@@ -582,9 +582,9 @@ def main():
return model_inputs
def feature_extraction_fn(examples, check_image=True):
def image_processing_fn(examples, check_image=True):
"""
Run feature extraction on images
Run preprocessing on images
If `check_image` is `True`, the examples that fail during `Image.open()` will be caught and discarded.
Otherwise, an exception will be thrown.
@@ -609,18 +609,18 @@ def main():
else:
images = [Image.open(image_file) for image_file in examples[image_column]]
encoder_inputs = feature_extractor(images=images, return_tensors="np")
encoder_inputs = image_processor(images=images, return_tensors="np")
model_inputs["pixel_values"] = encoder_inputs.pixel_values
return model_inputs
def preprocess_fn(examples, max_target_length, check_image=True):
"""Run tokenization + image feature extraction"""
"""Run tokenization + image processing"""
model_inputs = {}
# This contains image path column
model_inputs.update(tokenization_fn(examples, max_target_length))
model_inputs.update(feature_extraction_fn(model_inputs, check_image=check_image))
model_inputs.update(image_processing_fn(model_inputs, check_image=check_image))
# Remove image path column
model_inputs.pop(image_column)
@@ -644,15 +644,15 @@ def main():
}
)
# If `block_size` is `0`, tokenization & image feature extraction is done at the beginning
run_feat_ext_at_beginning = training_args.block_size == 0
# If `block_size` is `0`, tokenization & image processing is done at the beginning
run_img_proc_at_beginning = training_args.block_size == 0
# Used in .map() below
function_kwarg = preprocess_fn if run_feat_ext_at_beginning else tokenization_fn
function_kwarg = preprocess_fn if run_img_proc_at_beginning else tokenization_fn
# `features` is used only for the final preprocessed dataset (for performance purposes).
features_kwarg = features if run_feat_ext_at_beginning else None
# Keep `image_column` if the feature extraction is done during training
remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
processor_names = "tokenizer and feature extractor" if run_feat_ext_at_beginning else "tokenizer"
features_kwarg = features if run_img_proc_at_beginning else None
# Keep `image_column` if the image processing is done during training
remove_columns_kwarg = [x for x in column_names if x != image_column or run_img_proc_at_beginning]
processor_names = "tokenizer and image processor" if run_img_proc_at_beginning else "tokenizer"
# Store some constants
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
@@ -671,9 +671,9 @@ def main():
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
# instead of here.)
if not run_feat_ext_at_beginning:
if not run_img_proc_at_beginning:
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
train_dataset = train_dataset.map(
function=function_kwarg,
@@ -686,7 +686,7 @@ def main():
fn_kwargs={"max_target_length": data_args.max_target_length},
features=features_kwarg,
)
if run_feat_ext_at_beginning:
if run_img_proc_at_beginning:
# set format (for performance) since the dataset is ready to be used
train_dataset = train_dataset.with_format("numpy")
@@ -705,9 +705,9 @@ def main():
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
# instead of here.)
if not run_feat_ext_at_beginning:
if not run_img_proc_at_beginning:
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
eval_dataset = eval_dataset.map(
function=function_kwarg,
@@ -720,7 +720,7 @@ def main():
fn_kwargs={"max_target_length": data_args.val_max_target_length},
features=features_kwarg,
)
if run_feat_ext_at_beginning:
if run_img_proc_at_beginning:
# set format (for performance) since the dataset is ready to be used
eval_dataset = eval_dataset.with_format("numpy")
@@ -735,9 +735,9 @@ def main():
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
predict_dataset = predict_dataset.select(range(max_predict_samples))
# remove problematic examples
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
# instead of here.)
if not run_feat_ext_at_beginning:
if not run_img_proc_at_beginning:
predict_dataset = predict_dataset.filter(
filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers
)
@@ -752,7 +752,7 @@ def main():
fn_kwargs={"max_target_length": data_args.val_max_target_length},
features=features_kwarg,
)
if run_feat_ext_at_beginning:
if run_img_proc_at_beginning:
# set format (for performance) since the dataset is ready to be used
predict_dataset = predict_dataset.with_format("numpy")
@@ -771,8 +771,8 @@ def main():
"""
Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image feature
extraction (with the column name being specified by `image_column`). The tokenization should be done before
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image
processing (with the column name being specified by `image_column`). The tokenization should be done before
training in this case.
"""
@@ -804,7 +804,7 @@ def main():
_ds = ds.select(selected_indices)
_ds = _ds.map(
feature_extraction_fn,
image_processing_fn,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[image_column],
@@ -813,7 +813,7 @@ def main():
keep_in_memory=keep_in_memory,
# The images are already checked either in `.filter()` or in `preprocess_fn()`
fn_kwargs={"check_image": False},
desc=f"Running feature extraction on {split} dataset".replace(" ", " "),
desc=f"Running image processing on {split} dataset".replace(" ", " "),
)
_ds = _ds.with_format("numpy")