diff --git a/examples/pytorch/image-classification/requirements.txt b/examples/pytorch/image-classification/requirements.txt index 08c6d3bc1d..a789fee85e 100644 --- a/examples/pytorch/image-classification/requirements.txt +++ b/examples/pytorch/image-classification/requirements.txt @@ -1,2 +1,3 @@ -torch>=1.9.0 -torchvision>=0.10.0 \ No newline at end of file +torch>=1.5.0 +torchvision>=0.6.0 +datasets>=1.8.0 \ No newline at end of file diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py index 7a3daf72bc..fe0ccb0fb3 100644 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -56,7 +56,7 @@ logger = logging.getLogger(__name__) # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.12.0.dev0") -require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -102,7 +102,6 @@ class DataTrainingArguments: "value if set." }, ) - image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."}) def __post_init__(self): data_files = dict() @@ -210,35 +209,6 @@ def main(): task="image-classification", ) - # Define torchvision transforms to be applied to each image. - normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - _train_transforms = Compose( - [ - RandomResizedCrop(data_args.image_size), - RandomHorizontalFlip(), - ToTensor(), - normalize, - ] - ) - _val_transforms = Compose( - [ - Resize(data_args.image_size), - CenterCrop(data_args.image_size), - ToTensor(), - normalize, - ] - ) - - def train_transforms(example_batch): - """Apply _train_transforms across a batch.""" - example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] - return example_batch - - def val_transforms(example_batch): - """Apply _val_transforms across a batch.""" - example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] - return example_batch - # If we don't have a validation split, split off a percentage of train as validation. data_args.train_val_split = None if "validation" in ds.keys() else data_args.train_val_split if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0: @@ -281,20 +251,42 @@ def main(): revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) - # NOTE - We aren't directly using this feature extractor since we defined custom transforms above. - # We initialize this instance below and pass it to Trainer to ensure that the feature extraction - # config, preprocessor_config.json, is included in output directories. - # This way if we push a model to the hub, the inference widget will work. feature_extractor = AutoFeatureExtractor.from_pretrained( model_args.feature_extractor_name or model_args.model_name_or_path, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, - size=data_args.image_size, - image_mean=normalize.mean, - image_std=normalize.std, ) + # Define torchvision transforms to be applied to each image. + normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) + _train_transforms = Compose( + [ + RandomResizedCrop(feature_extractor.size), + RandomHorizontalFlip(), + ToTensor(), + normalize, + ] + ) + _val_transforms = Compose( + [ + Resize(feature_extractor.size), + CenterCrop(feature_extractor.size), + ToTensor(), + normalize, + ] + ) + + def train_transforms(example_batch): + """Apply _train_transforms across a batch.""" + example_batch["pixel_values"] = [_train_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] + return example_batch + + def val_transforms(example_batch): + """Apply _val_transforms across a batch.""" + example_batch["pixel_values"] = [_val_transforms(pil_loader(f)) for f in example_batch["image_file_path"]] + return example_batch + if training_args.do_train: if "train" not in ds: raise ValueError("--do_train requires a train dataset")