Update examples with image processors (#21155)

* Update examples to use image processors

* Small fixes

* Resolve conflicts
This commit is contained in:
amyeroberts
2023-01-19 15:14:58 +00:00
committed by GitHub
parent fc8a93507c
commit 4bc18e7a83
12 changed files with 124 additions and 137 deletions

View File

@@ -29,7 +29,7 @@ from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
ViTFeatureExtractor,
ViTImageProcessor,
ViTMAEConfig,
ViTMAEForPreTraining,
)
@@ -102,7 +102,7 @@ class DataTrainingArguments:
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
Arguments pertaining to which model/config/image processor we are going to pre-train.
"""
model_name_or_path: str = field(
@@ -132,7 +132,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@@ -230,7 +230,7 @@ def main():
ds["train"] = split["train"]
ds["validation"] = split["test"]
# Load pretrained model and feature extractor
# Load pretrained model and image processor
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
@@ -260,13 +260,13 @@ def main():
}
)
# create feature extractor
if model_args.feature_extractor_name:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
# create image processor
if model_args.image_processor_name:
image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
elif model_args.model_name_or_path:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
feature_extractor = ViTFeatureExtractor()
image_processor = ViTImageProcessor()
# create model
if model_args.model_name_or_path:
@@ -298,17 +298,17 @@ def main():
# transformations as done in original MAE paper
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
size = (image_processor.size["height"], image_processor.size["width"])
transforms = Compose(
[
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@@ -349,7 +349,7 @@ def main():
args=training_args,
train_dataset=ds["train"] if training_args.do_train else None,
eval_dataset=ds["validation"] if training_args.do_eval else None,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=collate_fn,
)