Update examples with image processors (#21155)
* Update examples to use image processors * Small fixes * Resolve conflicts
This commit is contained in:
@@ -47,7 +47,7 @@ from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
AutoImageProcessor,
|
||||
AutoTokenizer,
|
||||
FlaxVisionEncoderDecoderModel,
|
||||
HfArgumentParser,
|
||||
@@ -106,12 +106,12 @@ class TrainingArguments:
|
||||
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
|
||||
)
|
||||
_block_size_doc = """
|
||||
The default value `0` will preprocess (tokenization + feature extraction) the whole dataset before training and
|
||||
The default value `0` will preprocess (tokenization + image processing) the whole dataset before training and
|
||||
cache the results. This uses more disk space, but avoids (repeated) processing time during training. This is a
|
||||
good option if your disk space is large enough to store the whole processed dataset.
|
||||
If a positive value is given, the captions in the dataset will be tokenized before training and the results are
|
||||
cached. During training, it iterates the dataset in chunks of size `block_size`. On each block, images are
|
||||
transformed by the feature extractor with the results being kept in memory (no cache), and batches of size
|
||||
transformed by the image processor with the results being kept in memory (no cache), and batches of size
|
||||
`batch_size` are yielded before processing the next block. This could avoid the heavy disk usage when the
|
||||
dataset is large.
|
||||
"""
|
||||
@@ -477,7 +477,7 @@ def main():
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
image_processor = AutoImageProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
@@ -546,7 +546,7 @@ def main():
|
||||
for image_file in examples[image_column]:
|
||||
try:
|
||||
image = Image.open(image_file)
|
||||
feature_extractor(images=image, return_tensors="np")
|
||||
image_processor(images=image, return_tensors="np")
|
||||
bools.append(True)
|
||||
except Exception:
|
||||
bools.append(False)
|
||||
@@ -582,9 +582,9 @@ def main():
|
||||
|
||||
return model_inputs
|
||||
|
||||
def feature_extraction_fn(examples, check_image=True):
|
||||
def image_processing_fn(examples, check_image=True):
|
||||
"""
|
||||
Run feature extraction on images
|
||||
Run preprocessing on images
|
||||
|
||||
If `check_image` is `True`, the examples that fails during `Image.open()` will be caught and discarded.
|
||||
Otherwise, an exception will be thrown.
|
||||
@@ -609,18 +609,18 @@ def main():
|
||||
else:
|
||||
images = [Image.open(image_file) for image_file in examples[image_column]]
|
||||
|
||||
encoder_inputs = feature_extractor(images=images, return_tensors="np")
|
||||
encoder_inputs = image_processor(images=images, return_tensors="np")
|
||||
model_inputs["pixel_values"] = encoder_inputs.pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_fn(examples, max_target_length, check_image=True):
|
||||
"""Run tokenization + image feature extraction"""
|
||||
"""Run tokenization + image processing"""
|
||||
|
||||
model_inputs = {}
|
||||
# This contains image path column
|
||||
model_inputs.update(tokenization_fn(examples, max_target_length))
|
||||
model_inputs.update(feature_extraction_fn(model_inputs, check_image=check_image))
|
||||
model_inputs.update(image_processing_fn(model_inputs, check_image=check_image))
|
||||
# Remove image path column
|
||||
model_inputs.pop(image_column)
|
||||
|
||||
@@ -644,15 +644,15 @@ def main():
|
||||
}
|
||||
)
|
||||
|
||||
# If `block_size` is `0`, tokenization & image feature extraction is done at the beginning
|
||||
run_feat_ext_at_beginning = training_args.block_size == 0
|
||||
# If `block_size` is `0`, tokenization & image processing is done at the beginning
|
||||
run_img_proc_at_beginning = training_args.block_size == 0
|
||||
# Used in .map() below
|
||||
function_kwarg = preprocess_fn if run_feat_ext_at_beginning else tokenization_fn
|
||||
function_kwarg = preprocess_fn if run_img_proc_at_beginning else tokenization_fn
|
||||
# `features` is used only for the final preprocessed dataset (for the performance purpose).
|
||||
features_kwarg = features if run_feat_ext_at_beginning else None
|
||||
# Keep `image_column` if the feature extraction is done during training
|
||||
remove_columns_kwarg = [x for x in column_names if x != image_column or run_feat_ext_at_beginning]
|
||||
processor_names = "tokenizer and feature extractor" if run_feat_ext_at_beginning else "tokenizer"
|
||||
features_kwarg = features if run_img_proc_at_beginning else None
|
||||
# Keep `image_column` if the image processing is done during training
|
||||
remove_columns_kwarg = [x for x in column_names if x != image_column or run_img_proc_at_beginning]
|
||||
processor_names = "tokenizer and image processor" if run_img_proc_at_beginning else "tokenizer"
|
||||
|
||||
# Store some constant
|
||||
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
||||
@@ -671,9 +671,9 @@ def main():
|
||||
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
||||
train_dataset = train_dataset.select(range(max_train_samples))
|
||||
# remove problematic examples
|
||||
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
||||
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
|
||||
# instead here.)
|
||||
if not run_feat_ext_at_beginning:
|
||||
if not run_img_proc_at_beginning:
|
||||
train_dataset = train_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
||||
train_dataset = train_dataset.map(
|
||||
function=function_kwarg,
|
||||
@@ -686,7 +686,7 @@ def main():
|
||||
fn_kwargs={"max_target_length": data_args.max_target_length},
|
||||
features=features_kwarg,
|
||||
)
|
||||
if run_feat_ext_at_beginning:
|
||||
if run_img_proc_at_beginning:
|
||||
# set format (for performance) since the dataset is ready to be used
|
||||
train_dataset = train_dataset.with_format("numpy")
|
||||
|
||||
@@ -705,9 +705,9 @@ def main():
|
||||
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
||||
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
||||
# remove problematic examples
|
||||
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
||||
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
|
||||
# instead here.)
|
||||
if not run_feat_ext_at_beginning:
|
||||
if not run_img_proc_at_beginning:
|
||||
eval_dataset = eval_dataset.filter(filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers)
|
||||
eval_dataset = eval_dataset.map(
|
||||
function=function_kwarg,
|
||||
@@ -720,7 +720,7 @@ def main():
|
||||
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
||||
features=features_kwarg,
|
||||
)
|
||||
if run_feat_ext_at_beginning:
|
||||
if run_img_proc_at_beginning:
|
||||
# set format (for performance) since the dataset is ready to be used
|
||||
eval_dataset = eval_dataset.with_format("numpy")
|
||||
|
||||
@@ -735,9 +735,9 @@ def main():
|
||||
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
|
||||
predict_dataset = predict_dataset.select(range(max_predict_samples))
|
||||
# remove problematic examples
|
||||
# (if feature extraction is performed at the beginning, the filtering is done during preprocessing below
|
||||
# (if image processing is performed at the beginning, the filtering is done during preprocessing below
|
||||
# instead here.)
|
||||
if not run_feat_ext_at_beginning:
|
||||
if not run_img_proc_at_beginning:
|
||||
predict_dataset = predict_dataset.filter(
|
||||
filter_fn, batched=True, num_proc=data_args.preprocessing_num_workers
|
||||
)
|
||||
@@ -752,7 +752,7 @@ def main():
|
||||
fn_kwargs={"max_target_length": data_args.val_max_target_length},
|
||||
features=features_kwarg,
|
||||
)
|
||||
if run_feat_ext_at_beginning:
|
||||
if run_img_proc_at_beginning:
|
||||
# set format (for performance) since the dataset is ready to be used
|
||||
predict_dataset = predict_dataset.with_format("numpy")
|
||||
|
||||
@@ -771,8 +771,8 @@ def main():
|
||||
"""
|
||||
Wrap the simple `data_loader` in a block-wise way if `block_size` > 0, else it's the same as `data_loader`.
|
||||
|
||||
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image feature
|
||||
extraction (with the column name being specified by `image_column`). The tokenization should be done before
|
||||
If `block_size` > 0, it requires `ds` to have a column that gives image paths in order to perform image
|
||||
processing (with the column name being specified by `image_column`). The tokenization should be done before
|
||||
training in this case.
|
||||
"""
|
||||
|
||||
@@ -804,7 +804,7 @@ def main():
|
||||
_ds = ds.select(selected_indices)
|
||||
|
||||
_ds = _ds.map(
|
||||
feature_extraction_fn,
|
||||
image_processing_fn,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=[image_column],
|
||||
@@ -813,7 +813,7 @@ def main():
|
||||
keep_in_memory=keep_in_memory,
|
||||
# The images are already checked either in `.filter()` or in `preprocess_fn()`
|
||||
fn_kwargs={"check_image": False},
|
||||
desc=f"Running feature extraction on {split} dataset".replace(" ", " "),
|
||||
desc=f"Running image processing on {split} dataset".replace(" ", " "),
|
||||
)
|
||||
_ds = _ds.with_format("numpy")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user