From 9db2eebbe2ba692757f2d141f1078bb54fb9b323 Mon Sep 17 00:00:00 2001 From: Johannes Kolbe <2843485+johko@users.noreply.github.com> Date: Fri, 8 Apr 2022 13:31:38 +0200 Subject: [PATCH] add vit tf doctest with @add_code_sample_docstrings (#16636) * add vit tf doctest with @add_code_sample_docstrings * add labels string back in Co-authored-by: Johannes Kolbe --- .../models/vit/modeling_tf_vit.py | 75 +++++++------------ utils/documentation_tests.txt | 1 + 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py index cbf935f4f7..e3c039ca83 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -33,14 +33,23 @@ from ...modeling_tf_utils import ( unpack_inputs, ) from ...tf_utils import shape_list -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_vit import ViTConfig logger = logging.get_logger(__name__) +# General docstring _CONFIG_FOR_DOC = "ViTConfig" -_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224" +_FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor" + +# Base docstring +_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" # Inspired by @@ -645,7 +654,14 @@ class TFViTModel(TFViTPreTrainedModel): @unpack_inputs @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) def call( self, pixel_values: Optional[TFModelInputType] = None, @@ -656,26 +672,6 @@ class TFViTModel(TFViTPreTrainedModel): return_dict: Optional[bool] = None, training: bool = False, ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: - r""" - Returns: - - Examples: - - ```python - >>> from transformers import ViTFeatureExtractor, TFViTModel - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k") - >>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k") - - >>> inputs = feature_extractor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> last_hidden_states = outputs.last_hidden_state - ```""" outputs = self.vit( pixel_values=pixel_values, @@ -744,7 +740,13 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification @unpack_inputs @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + processor_class=_FEAT_EXTRACTOR_FOR_DOC, + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) def call( self, pixel_values: Optional[TFModelInputType] = None, @@ -761,30 +763,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification Labels for computing the image classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - - Returns: - - Examples: - - ```python - >>> from transformers import ViTFeatureExtractor, TFViTForImageClassification - >>> import tensorflow as tf - >>> from PIL import Image - >>> import requests - - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") - >>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224") - - >>> inputs = feature_extractor(images=image, return_tensors="tf") - >>> outputs = model(**inputs) - >>> logits = outputs.logits - >>> # model predicts one of the 1000 ImageNet classes - >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0] - >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)]) - ```""" + """ outputs = self.vit( pixel_values=pixel_values, diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index ef4c3b55f5..8fe33b240e 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -36,6 +36,7 @@ src/transformers/models/van/modeling_van.py src/transformers/models/vilt/modeling_vilt.py src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py src/transformers/models/vit/modeling_vit.py +src/transformers/models/vit/modeling_tf_vit.py src/transformers/models/vit_mae/modeling_vit_mae.py src/transformers/models/wav2vec2/modeling_wav2vec2.py src/transformers/models/wav2vec2/tokenization_wav2vec2.py