diff --git a/examples/pytorch/object-detection/run_object_detection.py b/examples/pytorch/object-detection/run_object_detection.py index 8d722f4d5d..095b41a6a4 100644 --- a/examples/pytorch/object-detection/run_object_detection.py +++ b/examples/pytorch/object-detection/run_object_detection.py @@ -271,6 +271,10 @@ class DataTrainingArguments: ) }, ) + use_fast: Optional[bool] = field( + default=True, + metadata={"help": "Use a fast torchvision-base image processor if it is supported for a given model."}, + ) @dataclass @@ -427,6 +431,7 @@ def main(): size={"max_height": data_args.image_square_size, "max_width": data_args.image_square_size}, do_pad=True, pad_size={"height": data_args.image_square_size, "width": data_args.image_square_size}, + use_fast=data_args.use_fast, **common_pretrained_args, ) diff --git a/examples/pytorch/object-detection/run_object_detection_no_trainer.py b/examples/pytorch/object-detection/run_object_detection_no_trainer.py index dbfcb3fd97..b7ca051949 100644 --- a/examples/pytorch/object-detection/run_object_detection_no_trainer.py +++ b/examples/pytorch/object-detection/run_object_detection_no_trainer.py @@ -256,6 +256,12 @@ def parse_args(): default=1333, help="Image longest size will be resized to this value, then image will be padded to square.", ) + parser.add_argument( + "--use_fast", + type=bool, + default=True, + help="Use a fast torchvision-base image processor if it is supported for a given model.", + ) parser.add_argument( "--cache_dir", type=str, @@ -482,6 +488,7 @@ def main(): size={"max_height": args.image_square_size, "max_width": args.image_square_size}, do_pad=True, pad_size={"height": args.image_square_size, "width": args.image_square_size}, + use_fast=args.use_fast, **common_pretrained_args, )