Update examples with image processors (#21155)

* Update examples to use image processors

* Small fixes

* Resolve conflicts
This commit is contained in:
amyeroberts
2023-01-19 15:14:58 +00:00
committed by GitHub
parent fc8a93507c
commit 4bc18e7a83
12 changed files with 124 additions and 137 deletions

View File

@@ -29,7 +29,7 @@ from transformers import (
HfArgumentParser,
Trainer,
TrainingArguments,
ViTFeatureExtractor,
ViTImageProcessor,
ViTMAEConfig,
ViTMAEForPreTraining,
)
@@ -102,7 +102,7 @@ class DataTrainingArguments:
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
Arguments pertaining to which model/config/image processor we are going to pre-train.
"""
model_name_or_path: str = field(
@@ -132,7 +132,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@@ -230,7 +230,7 @@ def main():
ds["train"] = split["train"]
ds["validation"] = split["test"]
# Load pretrained model and feature extractor
# Load pretrained model and image processor
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
@@ -260,13 +260,13 @@ def main():
}
)
# create feature extractor
if model_args.feature_extractor_name:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
# create image processor
if model_args.image_processor_name:
image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
elif model_args.model_name_or_path:
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
feature_extractor = ViTFeatureExtractor()
image_processor = ViTImageProcessor()
# create model
if model_args.model_name_or_path:
@@ -298,17 +298,17 @@ def main():
# transformations as done in original MAE paper
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
if "shortest_edge" in feature_extractor.size:
size = feature_extractor.size["shortest_edge"]
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (feature_extractor.size["height"], feature_extractor.size["width"])
size = (image_processor.size["height"], image_processor.size["width"])
transforms = Compose(
[
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@@ -349,7 +349,7 @@ def main():
args=training_args,
train_dataset=ds["train"] if training_args.do_train else None,
eval_dataset=ds["validation"] if training_args.do_eval else None,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=collate_fn,
)

View File

@@ -27,10 +27,10 @@ from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalF
import transformers
from transformers import (
CONFIG_MAPPING,
FEATURE_EXTRACTOR_MAPPING,
IMAGE_PROCESSOR_MAPPING,
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
AutoConfig,
AutoFeatureExtractor,
AutoImageProcessor,
AutoModelForMaskedImageModeling,
HfArgumentParser,
Trainer,
@@ -115,7 +115,7 @@ class DataTrainingArguments:
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
Arguments pertaining to which model/config/image processor we are going to pre-train.
"""
model_name_or_path: str = field(
@@ -152,7 +152,7 @@ class ModelArguments:
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
use_auth_token: bool = field(
default=False,
metadata={
@@ -334,17 +334,16 @@ def main():
}
)
# create feature extractor
if model_args.feature_extractor_name:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
# create image processor
if model_args.image_processor_name:
image_processor = AutoImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
elif model_args.model_name_or_path:
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
image_processor = AutoImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
else:
FEATURE_EXTRACTOR_TYPES = {
conf.model_type: feature_extractor_class
for conf, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items()
IMAGE_PROCESSOR_TYPES = {
conf.model_type: image_processor_class for conf, image_processor_class in IMAGE_PROCESSOR_MAPPING.items()
}
feature_extractor = FEATURE_EXTRACTOR_TYPES[model_args.model_type]()
image_processor = IMAGE_PROCESSOR_TYPES[model_args.model_type]()
# create model
if model_args.model_name_or_path:
@@ -382,7 +381,7 @@ def main():
RandomResizedCrop(model_args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
RandomHorizontalFlip(),
ToTensor(),
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
]
)
@@ -427,7 +426,7 @@ def main():
args=training_args,
train_dataset=ds["train"] if training_args.do_train else None,
eval_dataset=ds["validation"] if training_args.do_eval else None,
tokenizer=feature_extractor,
tokenizer=image_processor,
data_collator=collate_fn,
)