Update examples with image processors (#21155)
* Update examples to use image processors * Small fixes * Resolve conflicts
This commit is contained in:
@@ -38,7 +38,7 @@ import transformers
|
||||
from transformers import (
|
||||
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoImageProcessor,
|
||||
AutoModelForImageClassification,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
@@ -141,7 +141,7 @@ class ModelArguments:
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@@ -283,19 +283,19 @@ def main():
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
model_args.feature_extractor_name or model_args.model_name_or_path,
|
||||
image_processor = AutoImageProcessor.from_pretrained(
|
||||
model_args.image_processor_name or model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
|
||||
# Define torchvision transforms to be applied to each image.
|
||||
if "shortest_edge" in feature_extractor.size:
|
||||
size = feature_extractor.size["shortest_edge"]
|
||||
if "shortest_edge" in image_processor.size:
|
||||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (feature_extractor.size["height"], feature_extractor.size["width"])
|
||||
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(size),
|
||||
@@ -352,7 +352,7 @@ def main():
|
||||
train_dataset=dataset["train"] if training_args.do_train else None,
|
||||
eval_dataset=dataset["validation"] if training_args.do_eval else None,
|
||||
compute_metrics=compute_metrics,
|
||||
tokenizer=feature_extractor,
|
||||
tokenizer=image_processor,
|
||||
data_collator=collate_fn,
|
||||
)
|
||||
|
||||
|
||||
@@ -41,13 +41,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoFeatureExtractor,
|
||||
AutoModelForImageClassification,
|
||||
SchedulerType,
|
||||
get_scheduler,
|
||||
)
|
||||
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, SchedulerType, get_scheduler
|
||||
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
@@ -294,7 +288,7 @@ def main():
|
||||
label2id = {label: str(i) for i, label in enumerate(labels)}
|
||||
id2label = {str(i): label for i, label in enumerate(labels)}
|
||||
|
||||
# Load pretrained model and feature extractor
|
||||
# Load pretrained model and image processor
|
||||
#
|
||||
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
|
||||
# download model & vocab.
|
||||
@@ -305,7 +299,7 @@ def main():
|
||||
label2id=label2id,
|
||||
finetuning_task="image-classification",
|
||||
)
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_name_or_path)
|
||||
image_processor = AutoImageProcessor.from_pretrained(args.model_name_or_path)
|
||||
model = AutoModelForImageClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
@@ -316,11 +310,11 @@ def main():
|
||||
# Preprocessing the datasets
|
||||
|
||||
# Define torchvision transforms to be applied to each image.
|
||||
if "shortest_edge" in feature_extractor.size:
|
||||
size = feature_extractor.size["shortest_edge"]
|
||||
if "shortest_edge" in image_processor.size:
|
||||
size = image_processor.size["shortest_edge"]
|
||||
else:
|
||||
size = (feature_extractor.size["height"], feature_extractor.size["width"])
|
||||
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
size = (image_processor.size["height"], image_processor.size["width"])
|
||||
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
|
||||
train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(size),
|
||||
@@ -505,7 +499,7 @@ def main():
|
||||
save_function=accelerator.save,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
feature_extractor.save_pretrained(args.output_dir)
|
||||
image_processor.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(
|
||||
commit_message=f"Training in progress {completed_steps} steps",
|
||||
blocking=False,
|
||||
@@ -547,7 +541,7 @@ def main():
|
||||
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
feature_extractor.save_pretrained(args.output_dir)
|
||||
image_processor.save_pretrained(args.output_dir)
|
||||
repo.push_to_hub(
|
||||
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
||||
)
|
||||
@@ -568,7 +562,7 @@ def main():
|
||||
args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
feature_extractor.save_pretrained(args.output_dir)
|
||||
image_processor.save_pretrained(args.output_dir)
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user