From 9efad4efedd43de6935724b72a481bb8f14d35e3 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Wed, 21 Dec 2022 10:09:50 +0100 Subject: [PATCH] [Swin2SR] Add doc tests (#20829) * Fix doc tests * Use Auto API * Apply suggestion * Revert "Apply suggestion" This reverts commit cd9507a86644b4877c3e4a3d6c2d5919d9272dd7. Co-authored-by: Niels Rogge Co-authored-by: Niels Rogge --- .../models/swin2sr/modeling_swin2sr.py | 37 ++++++++++++------- utils/documentation_tests.txt | 1 + 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py index 0edc2feea8..b40ce8868c 100644 --- a/src/transformers/models/swin2sr/modeling_swin2sr.py +++ b/src/transformers/models/swin2sr/modeling_swin2sr.py @@ -43,11 +43,11 @@ logger = logging.get_logger(__name__) # General docstring _CONFIG_FOR_DOC = "Swin2SRConfig" -_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor" +_FEAT_EXTRACTOR_FOR_DOC = "AutoImageProcessor" # Base docstring -_CHECKPOINT_FOR_DOC = "caidas/swin2sr-classicalsr-x2-64" -_EXPECTED_OUTPUT_SHAPE = [1, 64, 768] +_CHECKPOINT_FOR_DOC = "caidas/swin2SR-classical-sr-x2-64" +_EXPECTED_OUTPUT_SHAPE = [1, 180, 488, 648] SWIN2SR_PRETRAINED_MODEL_ARCHIVE_LIST = [ @@ -1141,19 +1141,28 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel): Example: ```python >>> import torch - >>> from transformers import Swin2SRFeatureExtractor, Swin2SRForImageSuperResolution - >>> from datasets import load_dataset + >>> import numpy as np + >>> from PIL import Image + >>> import requests - >>> feature_extractor = Swin2SRFeatureExtractor.from_pretrained("openai/whisper-base") - >>> model = Swin2SRForImageSuperResolution.from_pretrained("openai/whisper-base") + >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt") - >>> input_features = inputs.input_features - >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id - >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state - >>> list(last_hidden_state.shape) - [1, 2, 512] + >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64") + >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64") + + >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> # prepare image for the model + >>> inputs = processor(image, return_tensors="pt") + + >>> # forward pass + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy() + >>> output = np.moveaxis(output, source=0, destination=-1) + >>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 + >>> # you can visualize `output` with `Image.fromarray` ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index d0f097ac5f..b0fabcfd24 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -166,6 +166,7 @@ src/transformers/models/segformer/modeling_tf_segformer.py src/transformers/models/squeezebert/configuration_squeezebert.py src/transformers/models/swin/configuration_swin.py src/transformers/models/swin/modeling_swin.py +src/transformers/models/swin2sr/modeling_swin2sr.py src/transformers/models/swinv2/configuration_swinv2.py src/transformers/models/table_transformer/modeling_table_transformer.py src/transformers/models/time_series_transformer/configuration_time_series_transformer.py