[Swin2SR] Add doc tests (#20829)
* Fix doc tests * Use Auto API * Apply suggestion * Revert "Apply suggestion" This reverts commit cd9507a86644b4877c3e4a3d6c2d5919d9272dd7. Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local> Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
This commit is contained in:
@@ -43,11 +43,11 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
# General docstring
|
||||
_CONFIG_FOR_DOC = "Swin2SRConfig"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
|
||||
_FEAT_EXTRACTOR_FOR_DOC = "AutoImageProcessor"
|
||||
|
||||
# Base docstring
|
||||
_CHECKPOINT_FOR_DOC = "caidas/swin2sr-classicalsr-x2-64"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 64, 768]
|
||||
_CHECKPOINT_FOR_DOC = "caidas/swin2SR-classical-sr-x2-64"
|
||||
_EXPECTED_OUTPUT_SHAPE = [1, 180, 488, 648]
|
||||
|
||||
|
||||
SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
@@ -1141,19 +1141,28 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
|
||||
Example:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import Swin2SRFeatureExtractor, Swin2SRForImageSuperResolution
|
||||
>>> from datasets import load_dataset
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
|
||||
>>> feature_extractor = Swin2SRFeatureExtractor.from_pretrained("openai/whisper-base")
|
||||
>>> model = Swin2SRForImageSuperResolution.from_pretrained("openai/whisper-base")
|
||||
>>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
|
||||
|
||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
|
||||
>>> input_features = inputs.input_features
|
||||
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
|
||||
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
>>> list(last_hidden_state.shape)
|
||||
[1, 2, 512]
|
||||
>>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
|
||||
>>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
|
||||
|
||||
>>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
>>> # prepare image for the model
|
||||
>>> inputs = processor(image, return_tensors="pt")
|
||||
|
||||
>>> # forward pass
|
||||
>>> with torch.no_grad():
|
||||
... outputs = model(**inputs)
|
||||
|
||||
>>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
>>> output = np.moveaxis(output, source=0, destination=-1)
|
||||
>>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
|
||||
>>> # you can visualize `output` with `Image.fromarray`
|
||||
```"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
|
||||
@@ -166,6 +166,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py
|
||||
src/transformers/models/squeezebert/configuration_squeezebert.py
|
||||
src/transformers/models/swin/configuration_swin.py
|
||||
src/transformers/models/swin/modeling_swin.py
|
||||
src/transformers/models/swin2sr/modeling_swin2sr.py
|
||||
src/transformers/models/swinv2/configuration_swinv2.py
|
||||
src/transformers/models/table_transformer/modeling_table_transformer.py
|
||||
src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
|
||||
|
||||
Reference in New Issue
Block a user