Fix test and docs (#14399)
This commit is contained in:
@@ -795,6 +795,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
|
||||
Examples::
|
||||
|
||||
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
|
||||
>>> import tensorflow as tf
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
@@ -809,7 +810,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
|
||||
>>> logits = outputs.logits
|
||||
>>> # model predicts one of the 1000 ImageNet classes
|
||||
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
|
||||
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
||||
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
|
||||
"""
|
||||
inputs = input_processing(
|
||||
func=self.call,
|
||||
|
||||
@@ -371,7 +371,7 @@ class TFViTModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_inference_image_classification_head(self):
|
||||
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
|
||||
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
||||
|
||||
feature_extractor = self.default_feature_extractor
|
||||
image = prepare_img()
|
||||
|
||||
Reference in New Issue
Block a user