diff --git a/examples/pytorch/image-classification/run_image_classification.py b/examples/pytorch/image-classification/run_image_classification.py old mode 100644 new mode 100755 index 27a81f0094..d009297541 --- a/examples/pytorch/image-classification/run_image_classification.py +++ b/examples/pytorch/image-classification/run_image_classification.py @@ -28,6 +28,7 @@ from PIL import Image from torchvision.transforms import ( CenterCrop, Compose, + Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, @@ -325,7 +326,11 @@ def main(): size = image_processor.size["shortest_edge"] else: size = (image_processor.size["height"], image_processor.size["width"]) - normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + normalize = ( + Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") + else Lambda(lambda x: x) + ) _train_transforms = Compose( [ RandomResizedCrop(size), diff --git a/examples/pytorch/image-classification/run_image_classification_no_trainer.py b/examples/pytorch/image-classification/run_image_classification_no_trainer.py index 3e38a3e79a..52b5fabd89 100644 --- a/examples/pytorch/image-classification/run_image_classification_no_trainer.py +++ b/examples/pytorch/image-classification/run_image_classification_no_trainer.py @@ -32,6 +32,7 @@ from torch.utils.data import DataLoader from torchvision.transforms import ( CenterCrop, Compose, + Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, @@ -331,7 +332,11 @@ def main(): size = image_processor.size["shortest_edge"] else: size = (image_processor.size["height"], image_processor.size["width"]) - normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + normalize = ( + Normalize(mean=image_processor.image_mean, std=image_processor.image_std) + if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std") + else Lambda(lambda x: x) + ) train_transforms = Compose( [ RandomResizedCrop(size),