Update examples with image processors (#21155)
* Update examples to use image processors
* Small fixes
* Resolve conflicts
This commit is contained in:
@@ -29,7 +29,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
ViTFeatureExtractor,
|
||||
ViTImageProcessor,
|
||||
ViTMAEConfig,
|
||||
ViTMAEForPreTraining,
|
||||
)
|
||||
@@ -102,7 +102,7 @@ class DataTrainingArguments:
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments pertaining to which model/config/feature extractor we are going to pre-train.
|
||||
Arguments pertaining to which model/config/image processor we are going to pre-train.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
@@ -132,7 +132,7 @@ class ModelArguments:
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -230,7 +230,7 @@ def main():
|
||||
ds["train"] = split["train"]
|
||||
ds["validation"] = split["test"]
|
||||
|
||||
# Load pretrained model and feature extractor
|
||||
# Load pretrained model and image processor
|
||||
#
|
||||
# Distributed training:
|
||||
# The .from_pretrained methods guarantee that only one local process can concurrently
|
||||
@@ -260,13 +260,13 @@ def main():
|
||||
}
|
||||
)
|
||||
|
||||
# create feature extractor
|
||||
if model_args.feature_extractor_name:
|
||||
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.feature_extractor_name, **config_kwargs)
|
||||
# create image processor
|
||||
if model_args.image_processor_name:
|
||||
image_processor = ViTImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
|
||||
elif model_args.model_name_or_path:
|
||||
feature_extractor = ViTFeatureExtractor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
image_processor = ViTImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
else:
|
||||
feature_extractor = ViTFeatureExtractor()
|
||||
image_processor = ViTImageProcessor()
|
||||
|
||||
# create model
|
||||
if model_args.model_name_or_path:
|
||||
@@ -298,17 +298,17 @@ def main():
|
||||
|
||||
# transformations as done in original MAE paper
|
||||
# source: https://github.com/facebookresearch/mae/blob/main/main_pretrain.py
|
||||
if "shortest_edge" in feature_extractor.size:
|
||||
size = feature_extractor.size["shortest_edge"]
|
||||
if "shortest_edge" in image_processor.size:
|
||||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (feature_extractor.size["height"], feature_extractor.size["width"])
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
transforms = Compose(
|
||||
[
|
||||
Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
|
||||
RandomResizedCrop(size, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
|
||||
Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -349,7 +349,7 @@ def main():
|
||||
args=training_args,
|
||||
train_dataset=ds["train"] if training_args.do_train else None,
|
||||
eval_dataset=ds["validation"] if training_args.do_eval else None,
|
||||
tokenizer=feature_extractor,
|
||||
tokenizer=image_processor,
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user