From 2237127a6cdaa3459677878a51fe7ed363e6556f Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 17 Apr 2023 16:11:52 +0100 Subject: [PATCH] Fix sneaky torch dependency in TF example (#22804) --- .../image-classification/run_image_classification.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/examples/tensorflow/image-classification/run_image_classification.py b/examples/tensorflow/image-classification/run_image_classification.py index 105d5c8056..ee6ebfb469 100644 --- a/examples/tensorflow/image-classification/run_image_classification.py +++ b/examples/tensorflow/image-classification/run_image_classification.py @@ -34,7 +34,7 @@ from PIL import Image import transformers from transformers import ( - MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, AutoConfig, AutoImageProcessor, DefaultDataCollator, @@ -58,7 +58,7 @@ check_min_version("4.29.0.dev0") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") -MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) +MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) @@ -262,11 +262,6 @@ def main(): transformers.utils.logging.set_verbosity_info() transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() - # Log on each process the small summary: - logger.warning( - f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" - + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - ) logger.info(f"Training/evaluation parameters {training_args}") # region Dataset and labels