From e5bc438cc8bae1cfd696b8f37e2a171fdb0fc806 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Tue, 15 Mar 2022 13:35:02 +0100 Subject: [PATCH] [Fix doc example] Fix 2 PyTorch Vilt docstring examples (#16076) * fix 2 pytorch vilt docstring examples * add vilt to doctest list file * remove device Co-authored-by: ydshieh --- src/transformers/models/vilt/modeling_vilt.py | 12 ++++++------ utils/documentation_tests.txt | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index d51f9f5e3a..b96846574a 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -933,6 +933,7 @@ class ViltForMaskedLM(ViltPreTrainedModel): >>> import requests >>> from PIL import Image >>> import re + >>> import torch >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) @@ -954,9 +955,9 @@ class ViltForMaskedLM(ViltPreTrainedModel): >>> with torch.no_grad(): ... for i in range(tl): ... encoded = processor.tokenizer(inferred_token) - ... input_ids = torch.tensor(encoded.input_ids).to(device) + ... input_ids = torch.tensor(encoded.input_ids) ... encoded = encoded["input_ids"][0][1:-1] - ... outputs = model(input_ids=input_ids, pixel_values=pixel_values) + ... outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values) ... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size) ... # only take into account text features (minus CLS and SEP token) ... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :] @@ -969,7 +970,8 @@ class ViltForMaskedLM(ViltPreTrainedModel): >>> selected_token = "" >>> encoded = processor.tokenizer(inferred_token) - >>> processor.decode(encoded.input_ids[0], skip_special_tokens=True) + >>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True) + >>> print(output) a bunch of cats laying on a couch. ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1215,12 +1217,10 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel): >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco") >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco") - >>> # prepare inputs - >>> encoding = processor(image, text, return_tensors="pt") - >>> # forward pass >>> scores = dict() >>> for text in texts: + ... # prepare inputs ... encoding = processor(image, text, return_tensors="pt") ... outputs = model(**encoding) ... scores[text] = outputs.logits[0, :].item() diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 2413f0c813..c425e5fbd7 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -18,6 +18,7 @@ src/transformers/models/swin/modeling_swin.py src/transformers/models/convnext/modeling_convnext.py src/transformers/models/poolformer/modeling_poolformer.py src/transformers/models/vit_mae/modeling_vit_mae.py +src/transformers/models/vilt/modeling_vilt.py src/transformers/models/van/modeling_van.py src/transformers/models/segformer/modeling_segformer.py src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py