✨ update image classification example (#13824)
* ✨ update image classification example
* 📌 update reqs
This commit is contained in:
@@ -1,2 +1,3 @@
|
||||
torch>=1.9.0
|
||||
torchvision>=0.10.0
|
||||
torch>=1.5.0
|
||||
torchvision>=0.6.0
|
||||
datasets>=1.8.0
|
||||
@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
|
||||
check_min_version("4.12.0.dev0")
|
||||
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
@@ -102,7 +102,6 @@ class DataTrainingArguments:
|
||||
"value if set."
|
||||
},
|
||||
)
|
||||
image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
|
||||
|
||||
def __post_init__(self):
|
||||
data_files = dict()
|
||||
@@ -210,35 +209,6 @@ def main():
|
||||
task="image-classification",
|
||||
)
|
||||
|
||||
# Define torchvision transforms to be applied to each image.
|
||||
normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(data_args.image_size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
_val_transforms = Compose(
|
||||
[
|
||||
Resize(data_args.image_size),
|
||||
CenterCrop(data_args.image_size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
|
||||
def train_transforms(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||
return example_batch
|
||||
|
||||
def val_transforms(example_batch):
|
||||
"""Apply _val_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||
return example_batch
|
||||
|
||||
# If we don't have a validation split, split off a percentage of train as validation.
|
||||
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
|
||||
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
||||
@@ -281,20 +251,42 @@ def main():
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
)
|
||||
# NOTE - We aren't directly using this feature extractor since we defined custom transforms above.
|
||||
# We initialize this instance below and pass it to Trainer to ensure that the feature extraction
|
||||
# config, preprocessor_config.json, is included in output directories.
|
||||
# This way if we push a model to the hub, the inference widget will work.
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
model_args.feature_extractor_name or model_args.model_name_or_path,
|
||||
cache_dir=model_args.cache_dir,
|
||||
revision=model_args.model_revision,
|
||||
use_auth_token=True if model_args.use_auth_token else None,
|
||||
size=data_args.image_size,
|
||||
image_mean=normalize.mean,
|
||||
image_std=normalize.std,
|
||||
)
|
||||
|
||||
# Define torchvision transforms to be applied to each image.
|
||||
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||
_train_transforms = Compose(
|
||||
[
|
||||
RandomResizedCrop(feature_extractor.size),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
_val_transforms = Compose(
|
||||
[
|
||||
Resize(feature_extractor.size),
|
||||
CenterCrop(feature_extractor.size),
|
||||
ToTensor(),
|
||||
normalize,
|
||||
]
|
||||
)
|
||||
|
||||
def train_transforms(example_batch):
|
||||
"""Apply _train_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||
return example_batch
|
||||
|
||||
def val_transforms(example_batch):
|
||||
"""Apply _val_transforms across a batch."""
|
||||
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||
return example_batch
|
||||
|
||||
if training_args.do_train:
|
||||
if "train" not in ds:
|
||||
raise ValueError("--do_train requires a train dataset")
|
||||
|
||||
Reference in New Issue
Block a user