✨ update image classification example (#13824)
* ✨ update image classification example * 📌 update reqs
This commit is contained in:
@@ -1,2 +1,3 @@
|
|||||||
torch>=1.9.0
|
torch>=1.5.0
|
||||||
torchvision>=0.10.0
|
torchvision>=0.6.0
|
||||||
|
datasets>=1.8.0
|
||||||
@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
|||||||
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
||||||
check_min_version("4.12.0.dev0")
|
check_min_version("4.12.0.dev0")
|
||||||
|
|
||||||
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
|
||||||
|
|
||||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
|
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
|
||||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||||
@@ -102,7 +102,6 @@ class DataTrainingArguments:
|
|||||||
"value if set."
|
"value if set."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
data_files = dict()
|
data_files = dict()
|
||||||
@@ -210,35 +209,6 @@ def main():
|
|||||||
task="image-classification",
|
task="image-classification",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Define torchvision transforms to be applied to each image.
|
|
||||||
normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
||||||
_train_transforms = Compose(
|
|
||||||
[
|
|
||||||
RandomResizedCrop(data_args.image_size),
|
|
||||||
RandomHorizontalFlip(),
|
|
||||||
ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
_val_transforms = Compose(
|
|
||||||
[
|
|
||||||
Resize(data_args.image_size),
|
|
||||||
CenterCrop(data_args.image_size),
|
|
||||||
ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def train_transforms(example_batch):
|
|
||||||
"""Apply _train_transforms across a batch."""
|
|
||||||
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
|
||||||
return example_batch
|
|
||||||
|
|
||||||
def val_transforms(example_batch):
|
|
||||||
"""Apply _val_transforms across a batch."""
|
|
||||||
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
|
||||||
return example_batch
|
|
||||||
|
|
||||||
# If we don't have a validation split, split off a percentage of train as validation.
|
# If we don't have a validation split, split off a percentage of train as validation.
|
||||||
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
|
data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split
|
||||||
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
|
||||||
@@ -281,20 +251,42 @@ def main():
|
|||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
)
|
)
|
||||||
# NOTE - We aren't directly using this feature extractor since we defined custom transforms above.
|
|
||||||
# We initialize this instance below and pass it to Trainer to ensure that the feature extraction
|
|
||||||
# config, preprocessor_config.json, is included in output directories.
|
|
||||||
# This way if we push a model to the hub, the inference widget will work.
|
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
model_args.feature_extractor_name or model_args.model_name_or_path,
|
model_args.feature_extractor_name or model_args.model_name_or_path,
|
||||||
cache_dir=model_args.cache_dir,
|
cache_dir=model_args.cache_dir,
|
||||||
revision=model_args.model_revision,
|
revision=model_args.model_revision,
|
||||||
use_auth_token=True if model_args.use_auth_token else None,
|
use_auth_token=True if model_args.use_auth_token else None,
|
||||||
size=data_args.image_size,
|
|
||||||
image_mean=normalize.mean,
|
|
||||||
image_std=normalize.std,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Define torchvision transforms to be applied to each image.
|
||||||
|
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
|
||||||
|
_train_transforms = Compose(
|
||||||
|
[
|
||||||
|
RandomResizedCrop(feature_extractor.size),
|
||||||
|
RandomHorizontalFlip(),
|
||||||
|
ToTensor(),
|
||||||
|
normalize,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_val_transforms = Compose(
|
||||||
|
[
|
||||||
|
Resize(feature_extractor.size),
|
||||||
|
CenterCrop(feature_extractor.size),
|
||||||
|
ToTensor(),
|
||||||
|
normalize,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_transforms(example_batch):
|
||||||
|
"""Apply _train_transforms across a batch."""
|
||||||
|
example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||||
|
return example_batch
|
||||||
|
|
||||||
|
def val_transforms(example_batch):
|
||||||
|
"""Apply _val_transforms across a batch."""
|
||||||
|
example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]]
|
||||||
|
return example_batch
|
||||||
|
|
||||||
if training_args.do_train:
|
if training_args.do_train:
|
||||||
if "train" not in ds:
|
if "train" not in ds:
|
||||||
raise ValueError("--do_train requires a train dataset")
|
raise ValueError("--do_train requires a train dataset")
|
||||||
|
|||||||
Reference in New Issue
Block a user