From 515ed3ad2a11a6b0cd9800b2ad4d3b313fdaea8c Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Thu, 20 Jan 2022 21:51:51 +0100 Subject: [PATCH] Fix doc examples (#15257) --- docs/source/model_doc/trocr.mdx | 3 +- src/transformers/models/vilt/modeling_vilt.py | 46 +++++++++++++++---- 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/docs/source/model_doc/trocr.mdx b/docs/source/model_doc/trocr.mdx index 494895c1a8..08de107e43 100644 --- a/docs/source/model_doc/trocr.mdx +++ b/docs/source/model_doc/trocr.mdx @@ -70,7 +70,8 @@ into a single instance to both extract the input features and decode the predict >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten") ->>> # load image from the IAM dataset url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" +>>> # load image from the IAM dataset +>>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") >>> pixel_values = processor(image, return_tensors="pt").pixel_values diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 28343fe791..d51f9f5e3a 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -42,10 +42,10 @@ from .configuration_vilt import ViltConfig logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "ViltConfig" -_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm-itm" +_CHECKPOINT_FOR_DOC = "dandelin/vilt-b32-mlm" VILT_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "dandelin/vilt-b32-mlm-itm", + "dandelin/vilt-b32-mlm", # See all ViLT models at https://huggingface.co/models?filter=vilt ] @@ -775,17 +775,19 @@ class ViltModel(ViltPreTrainedModel): Examples: ```python - >>> from transformers import ViltFeatureExtractor, ViltModel + >>> from transformers import ViltProcessor, ViltModel >>> from PIL import Image >>> import requests + >>> # prepare image and text >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) + >>> text = "hello world" - >>> feature_extractor = ViltFeatureExtractor.from_pretrained("dandelin/vilt-b32-mlm-itm") - >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm-itm") + >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm") + >>> model = ViltModel.from_pretrained("dandelin/vilt-b32-mlm") - >>> inputs = feature_extractor(images=image, return_tensors="pt") + >>> inputs = processor(image, text, return_tensors="pt") >>> outputs = model(**inputs) >>> last_hidden_states = outputs.last_hidden_state ```""" @@ -930,10 +932,11 @@ class ViltForMaskedLM(ViltPreTrainedModel): >>> from transformers import ViltProcessor, ViltForMaskedLM >>> import requests >>> from PIL import Image + >>> import re >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> text = "How many cats are there?" + >>> text = "a bunch of [MASK] laying on a [MASK]." >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-mlm") >>> model = ViltForMaskedLM.from_pretrained("dandelin/vilt-b32-mlm") @@ -943,7 +946,31 @@ class ViltForMaskedLM(ViltPreTrainedModel): >>> # forward pass >>> outputs = model(**encoding) - >>> logits = outputs.logits + + >>> tl = len(re.findall("\[MASK\]", text)) + >>> inferred_token = [text] + + >>> # gradually fill in the MASK tokens, one by one + >>> with torch.no_grad(): + ... for i in range(tl): + ... encoded = processor.tokenizer(inferred_token) + ... input_ids = torch.tensor(encoded.input_ids).to(device) + ... encoded = encoded["input_ids"][0][1:-1] + ... outputs = model(input_ids=input_ids, pixel_values=pixel_values) + ... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size) + ... # only take into account text features (minus CLS and SEP token) + ... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :] + ... mlm_values, mlm_ids = mlm_logits.softmax(dim=-1).max(dim=-1) + ... # only take into account text + ... mlm_values[torch.tensor(encoded) != 103] = 0 + ... select = mlm_values.argmax().item() + ... encoded[select] = mlm_ids[select].item() + ... inferred_token = [processor.decode(encoded)] + + >>> selected_token = "" + >>> encoded = processor.tokenizer(inferred_token) + >>> processor.decode(encoded.input_ids[0], skip_special_tokens=True) + a bunch of cats laying on a couch. ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1093,6 +1120,7 @@ class ViltForQuestionAnswering(ViltPreTrainedModel): >>> logits = outputs.logits >>> idx = logits.argmax(-1).item() >>> print("Predicted answer:", model.config.id2label[idx]) + Predicted answer: 2 ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1297,13 +1325,13 @@ class ViltForImagesAndTextClassification(ViltPreTrainedModel): >>> # prepare inputs >>> encoding = processor([image1, image2], text, return_tensors="pt") - >>> pixel_values = torch.stack([encoding_1.pixel_values, encoding_2.pixel_values], dim=1) >>> # forward pass >>> outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0)) >>> logits = outputs.logits >>> idx = logits.argmax(-1).item() >>> print("Predicted answer:", model.config.id2label[idx]) + Predicted answer: True ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = (