Fix test and docs (#14399)

This commit is contained in:
NielsRogge
2021-11-15 17:35:33 +01:00
committed by GitHub
parent 4ce74edf51
commit 74e6111ba7
2 changed files with 3 additions and 2 deletions

View File

@@ -795,6 +795,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
Examples::
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> import tensorflow as tf
>>> from PIL import Image
>>> import requests
@@ -809,7 +810,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
"""
inputs = input_processing(
func=self.call,

View File

@@ -371,7 +371,7 @@ class TFViTModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_image_classification_head(self):
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
feature_extractor = self.default_feature_extractor
image = prepare_img()