From 74e6111ba7b6b9cac0a087c60a74d09a738ab0bc Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 15 Nov 2021 17:35:33 +0100 Subject: [PATCH] Fix test and docs (#14399) --- src/transformers/models/vit/modeling_tf_vit.py | 3 ++- tests/test_modeling_tf_vit.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py index 54030b2ee3..8a809bfe8a 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -795,6 +795,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification Examples:: >>> from transformers import ViTFeatureExtractor, TFViTForImageClassification + >>> import tensorflow as tf >>> from PIL import Image >>> import requests @@ -809,7 +810,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification >>> logits = outputs.logits >>> # model predicts one of the 1000 ImageNet classes >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] - >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) + >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) """ inputs = input_processing( func=self.call, diff --git a/tests/test_modeling_tf_vit.py b/tests/test_modeling_tf_vit.py index 1a83b2824a..eb342aa68d 100644 --- a/tests/test_modeling_tf_vit.py +++ b/tests/test_modeling_tf_vit.py @@ -371,7 +371,7 @@ class TFViTModelIntegrationTest(unittest.TestCase): @slow def test_inference_image_classification_head(self): - model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True) + model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224") feature_extractor = self.default_feature_extractor image = prepare_img()