add vit tf doctest with @add_code_sample_docstrings (#16636)
* add vit tf doctest with @add_code_sample_docstrings
* add labels string back in

Co-authored-by: Johannes Kolbe <johannes.kolbe@tech.better.team>
This commit is contained in:
@@ -33,14 +33,23 @@ from ...modeling_tf_utils import (
|
|||||||
unpack_inputs,
|
unpack_inputs,
|
||||||
)
|
)
|
||||||
from ...tf_utils import shape_list
|
from ...tf_utils import shape_list
|
||||||
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
|
||||||
from .configuration_vit import ViTConfig
|
from .configuration_vit import ViTConfig
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
# General docstring
|
||||||
_CONFIG_FOR_DOC = "ViTConfig"
|
_CONFIG_FOR_DOC = "ViTConfig"
|
||||||
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
|
_FEAT_EXTRACTOR_FOR_DOC = "ViTFeatureExtractor"
|
||||||
|
|
||||||
|
# Base docstring
|
||||||
|
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
|
||||||
|
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
|
||||||
|
|
||||||
|
# Image classification docstring
|
||||||
|
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
|
||||||
|
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
|
||||||
|
|
||||||
|
|
||||||
# Inspired by
|
# Inspired by
|
||||||
@@ -645,7 +654,14 @@ class TFViTModel(TFViTPreTrainedModel):
|
|||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
|
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||||
|
output_type=TFBaseModelOutputWithPooling,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
modality="vision",
|
||||||
|
expected_output=_EXPECTED_OUTPUT_SHAPE,
|
||||||
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[TFModelInputType] = None,
|
pixel_values: Optional[TFModelInputType] = None,
|
||||||
@@ -656,26 +672,6 @@ class TFViTModel(TFViTPreTrainedModel):
|
|||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
|
||||||
r"""
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import ViTFeatureExtractor, TFViTModel
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
|
|
||||||
>>> model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
|
|
||||||
|
|
||||||
>>> inputs = feature_extractor(images=image, return_tensors="tf")
|
|
||||||
>>> outputs = model(**inputs)
|
|
||||||
>>> last_hidden_states = outputs.last_hidden_state
|
|
||||||
```"""
|
|
||||||
|
|
||||||
outputs = self.vit(
|
outputs = self.vit(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
@@ -744,7 +740,13 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
|
|||||||
|
|
||||||
@unpack_inputs
|
@unpack_inputs
|
||||||
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
@add_code_sample_docstrings(
|
||||||
|
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
|
||||||
|
checkpoint=_IMAGE_CLASS_CHECKPOINT,
|
||||||
|
output_type=TFSequenceClassifierOutput,
|
||||||
|
config_class=_CONFIG_FOR_DOC,
|
||||||
|
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
|
||||||
|
)
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
pixel_values: Optional[TFModelInputType] = None,
|
pixel_values: Optional[TFModelInputType] = None,
|
||||||
@@ -761,30 +763,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassification
|
|||||||
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
|
||||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
||||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||||
|
"""
|
||||||
Returns:
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
|
|
||||||
>>> import tensorflow as tf
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
|
|
||||||
>>> model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
|
|
||||||
|
|
||||||
>>> inputs = feature_extractor(images=image, return_tensors="tf")
|
|
||||||
>>> outputs = model(**inputs)
|
|
||||||
>>> logits = outputs.logits
|
|
||||||
>>> # model predicts one of the 1000 ImageNet classes
|
|
||||||
>>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
|
|
||||||
>>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
|
|
||||||
```"""
|
|
||||||
|
|
||||||
outputs = self.vit(
|
outputs = self.vit(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ src/transformers/models/van/modeling_van.py
|
|||||||
src/transformers/models/vilt/modeling_vilt.py
|
src/transformers/models/vilt/modeling_vilt.py
|
||||||
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
|
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
|
||||||
src/transformers/models/vit/modeling_vit.py
|
src/transformers/models/vit/modeling_vit.py
|
||||||
|
src/transformers/models/vit/modeling_tf_vit.py
|
||||||
src/transformers/models/vit_mae/modeling_vit_mae.py
|
src/transformers/models/vit_mae/modeling_vit_mae.py
|
||||||
src/transformers/models/wav2vec2/modeling_wav2vec2.py
|
src/transformers/models/wav2vec2/modeling_wav2vec2.py
|
||||||
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
|
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
|
||||||
|
|||||||
Reference in New Issue
Block a user