Fix test and docs (#14399)
This commit is contained in:
@@ -795,6 +795,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
|
|||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
|
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
|
||||||
|
>>> import tensorflow as tf
|
||||||
>>> from PIL import Image
|
>>> from PIL import Image
|
||||||
>>> import requests
|
>>> import requests
|
||||||
|
|
||||||
@@ -809,7 +810,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
|
|||||||
>>> logits = outputs.logits
|
>>> logits = outputs.logits
|
||||||
>>> # model predicts one of the 1000 ImageNet classes
|
>>> # model predicts one of the 1000 ImageNet classes
|
||||||
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
|
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
|
||||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
|
||||||
"""
|
"""
|
||||||
inputs = input_processing(
|
inputs = input_processing(
|
||||||
func=self.call,
|
func=self.call,
|
||||||
|
|||||||
@@ -371,7 +371,7 @@ class TFViTModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_image_classification_head(self):
|
def test_inference_image_classification_head(self):
|
||||||
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
|
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
||||||
|
|
||||||
feature_extractor = self.default_feature_extractor
|
feature_extractor = self.default_feature_extractor
|
||||||
image = prepare_img()
|
image = prepare_img()
|
||||||
|
|||||||
Reference in New Issue
Block a user