[Fix doc example] Fix 2 PyTorch Vilt docstring examples (#16076)
* fix 2 pytorch vilt docstring examples
* add vilt to doctest list file
* remove device

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -933,6 +933,7 @@ class ViltForMaskedLM(ViltPreTrainedModel):
|
|||||||
>>> import requests
|
>>> import requests
|
||||||
>>> from PIL import Image
|
>>> from PIL import Image
|
||||||
>>> import re
|
>>> import re
|
||||||
|
>>> import torch
|
||||||
|
|
||||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
@@ -954,9 +955,9 @@ class ViltForMaskedLM(ViltPreTrainedModel):
|
|||||||
>>> with torch.no_grad():
|
>>> with torch.no_grad():
|
||||||
... for i in range(tl):
|
... for i in range(tl):
|
||||||
... encoded = processor.tokenizer(inferred_token)
|
... encoded = processor.tokenizer(inferred_token)
|
||||||
... input_ids = torch.tensor(encoded.input_ids).to(device)
|
... input_ids = torch.tensor(encoded.input_ids)
|
||||||
... encoded = encoded["input_ids"][0][1:-1]
|
... encoded = encoded["input_ids"][0][1:-1]
|
||||||
... outputs = model(input_ids=input_ids, pixel_values=pixel_values)
|
... outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
|
||||||
... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
|
... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
|
||||||
... # only take into account text features (minus CLS and SEP token)
|
... # only take into account text features (minus CLS and SEP token)
|
||||||
... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
|
... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
|
||||||
@@ -969,7 +970,8 @@ class ViltForMaskedLM(ViltPreTrainedModel):
|
|||||||
|
|
||||||
>>> selected_token = ""
|
>>> selected_token = ""
|
||||||
>>> encoded = processor.tokenizer(inferred_token)
|
>>> encoded = processor.tokenizer(inferred_token)
|
||||||
>>> processor.decode(encoded.input_ids[0], skip_special_tokens=True)
|
>>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
|
||||||
|
>>> print(output)
|
||||||
a bunch of cats laying on a couch.
|
a bunch of cats laying on a couch.
|
||||||
```"""
|
```"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1215,12 +1217,10 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
|
|||||||
>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
|
>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
|
||||||
>>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
|
>>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
|
||||||
|
|
||||||
>>> # prepare inputs
|
|
||||||
>>> encoding = processor(image, text, return_tensors="pt")
|
|
||||||
|
|
||||||
>>> # forward pass
|
>>> # forward pass
|
||||||
>>> scores = dict()
|
>>> scores = dict()
|
||||||
>>> for text in texts:
|
>>> for text in texts:
|
||||||
|
... # prepare inputs
|
||||||
... encoding = processor(image, text, return_tensors="pt")
|
... encoding = processor(image, text, return_tensors="pt")
|
||||||
... outputs = model(**encoding)
|
... outputs = model(**encoding)
|
||||||
... scores[text] = outputs.logits[0, :].item()
|
... scores[text] = outputs.logits[0, :].item()
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ src/transformers/models/swin/modeling_swin.py
|
|||||||
src/transformers/models/convnext/modeling_convnext.py
|
src/transformers/models/convnext/modeling_convnext.py
|
||||||
src/transformers/models/poolformer/modeling_poolformer.py
|
src/transformers/models/poolformer/modeling_poolformer.py
|
||||||
src/transformers/models/vit_mae/modeling_vit_mae.py
|
src/transformers/models/vit_mae/modeling_vit_mae.py
|
||||||
|
src/transformers/models/vilt/modeling_vilt.py
|
||||||
src/transformers/models/van/modeling_van.py
|
src/transformers/models/van/modeling_van.py
|
||||||
src/transformers/models/segformer/modeling_segformer.py
|
src/transformers/models/segformer/modeling_segformer.py
|
||||||
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
|
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
|
||||||
|
|||||||
Reference in New Issue
Block a user