Update examples with image processors (#21155)
* Update examples to use image processors * Small fixes * Resolve conflicts
This commit is contained in:
@@ -29,7 +29,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
ViTFeatureExtractor,
|
||||
ViTImageProcessor,
|
||||
ViTMAEConfig,
|
||||
ViTMAEForPreTraining,
|
||||
)
|
||||
@@ -102,7 +102,7 @@ class DataTrainingArguments:
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
|
||||
Arguments pertaining to which model/config/image processor we are going to pre-train.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
@@ -132,7 +132,7 @@ class ModelArguments:
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -230,7 +230,7 @@ def main():
|
||||
ds["train"] = split["train"]
|
||||
ds["validation"] = split["test"]
|
||||
|
||||
# Load pretrained model and feature extractor
|
||||
# Load pretrained model and image processor
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
@@ -260,13 +260,13 @@ def main():
|
||||
}
|
||||
)
|
||||
|
||||
# create feature extractor
|
||||
if model_args.feature_extractor_name:
|
||||
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
|
||||
# create image processor
|
||||
if model_args.image_processor_name:
|
||||
image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
feature_extractor = ViTFeatureExtractor()
|
||||
image_processor = ViTImageProcessor()
|
||||
|
||||
# create model
|
||||
if model_args.model_name_or_path:
|
||||
@@ -298,17 +298,17 @@ def main():
|
||||
|
||||
# transformations as done in original MAE paper
|
||||
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
|
||||
if "shortest_edge" in feature_extractor.size:
|
||||
size = feature_extractor.size["shortest_edge"]
|
||||
if "shortest_edge" in image_processor.size:
|
||||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (feature_extractor.size["height"], feature_extractor.size["width"])
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
transforms = Compose(
|
||||
[
|
||||
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
|
||||
RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -349,7 +349,7 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=ds["train"] if training_args.do_train else None,
|
||||
eval_dataset=ds["validation"] if training_args.do_eval else None,
|
||||
tokenizer=feature_extractor,
|
||||
tokenizer=image_processor,
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalF
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
IMAGE_PROCESSOR_MAPPING,
|
||||
MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoImageProcessor,
|
||||
AutoModelForMaskedImageModeling,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
@@ -115,7 +115,7 @@ class DataTrainingArguments:
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
|
||||
Arguments pertaining to which model/config/image processor we are going to pre-train.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
@@ -152,7 +152,7 @@ class ModelArguments:
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -334,17 +334,16 @@ def main():
|
||||
}
|
||||
)
|
||||
|
||||
# create feature extractor
|
||||
if model_args.feature_extractor_name:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
|
||||
# create image processor
|
||||
if model_args.image_processor_name:
|
||||
image_processor = AutoImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
image_processor = AutoImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
FEATURE_EXTRACTOR_TYPES = {
|
||||
conf.model_type: feature_extractor_class
|
||||
for conf, feature_extractor_class in FEATURE_EXTRACTOR_MAPPING.items()
|
||||
IMAGE_PROCESSOR_TYPES = {
|
||||
conf.model_type: image_processor_class for conf, image_processor_class in IMAGE_PROCESSOR_MAPPING.items()
|
||||
}
|
||||
feature_extractor = FEATURE_EXTRACTOR_TYPES[model_args.model_type]()
|
||||
image_processor = IMAGE_PROCESSOR_TYPES[model_args.model_type]()
|
||||
|
||||
# create model
|
||||
if model_args.model_name_or_path:
|
||||
@@ -382,7 +381,7 @@ def main():
|
||||
RandomResizedCrop(model_args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -427,7 +426,7 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=ds["train"] if training_args.do_train else None,
|
||||
eval_dataset=ds["validation"] if training_args.do_eval else None,
|
||||
tokenizer=feature_extractor,
|
||||
tokenizer=image_processor,
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user