From ed31ab3f103e7b3b0b08658baa9b36174517fdd2 Mon Sep 17 00:00:00 2001 From: Arnaud Stiegler Date: Tue, 29 Mar 2022 10:19:06 -0400 Subject: [PATCH] Adding DocTest to TrOCR (#16398) * docstring still WIP | adding to documentation_tests * clean version | passes tests * adding to documentation_test * adding forward for training pass * make fixup applied * address comments * fix doctest * apply make fixup * remove additional blank * fix file to have correct split for prepare_for_doc_test * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * address comments * changing text | adding loss check | make fixup * make fixup * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * Update src/transformers/models/trocr/modeling_trocr.py Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> * make fixup Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> --- .../models/trocr/modeling_trocr.py | 46 +++++++++++++++++-- utils/documentation_tests.txt | 1 + 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 8a26739d87..75e015f988 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -597,8 +597,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel): If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all ``decoder_input_ids``` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` - of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of + shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. @@ -891,13 +891,49 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel): Example: ```python - >>> from transformers import VisionEncoderDecoderModel, TrOCRForCausalLM, ViTModel, TrOCRConfig, ViTConfig + >>> from transformers import ( + ... TrOCRConfig, + ... TrOCRProcessor, + ... TrOCRForCausalLM, + ... ViTConfig, + ... ViTModel, + ... VisionEncoderDecoderModel, + ... ) + >>> import requests + >>> from PIL import Image + >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel + >>> # init vision2text model with random weights >>> encoder = ViTModel(ViTConfig()) >>> decoder = TrOCRForCausalLM(TrOCRConfig()) - # init vision2text model - >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder) + + >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel` + >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") + >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten") + + >>> # load image from the IAM dataset + >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + >>> pixel_values = processor(image, return_tensors="pt").pixel_values + >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a" + + >>> # training + >>> model.config.decoder_start_token_id = processor.tokenizer.cls_token_id + >>> model.config.pad_token_id = processor.tokenizer.pad_token_id + >>> model.config.vocab_size = model.config.decoder.vocab_size + + >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids + >>> outputs = model(pixel_values, labels=labels) + >>> loss = outputs.loss + >>> round(loss.item(), 2) + 5.30 + + >>> # inference + >>> generated_ids = model.generate(pixel_values) + >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> generated_text + 'industry, " Mr. Brown commented icily. " Let us have a' ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions diff --git a/utils/documentation_tests.txt b/utils/documentation_tests.txt index 5e2435ef46..372e63ad23 100644 --- a/utils/documentation_tests.txt +++ b/utils/documentation_tests.txt @@ -39,6 +39,7 @@ src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.p src/transformers/models/speech_to_text/modeling_speech_to_text.py src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py src/transformers/models/swin/modeling_swin.py +src/transformers/models/trocr/modeling_trocr.py src/transformers/models/unispeech/modeling_unispeech.py src/transformers/models/unispeech_sat/modeling_unispeech_sat.py src/transformers/models/van/modeling_van.py