Add Donut (#18488)
* First draft * Improve script * Update script * Make conversion work * Add final_layer_norm attribute to Swin's config * Add DonutProcessor * Convert more models * Improve feature extractor and convert base models * Fix bug * Improve integration tests * Improve integration tests and add model to README * Add doc test * Add feature extractor to docs * Fix integration tests * Remove register_buffer * Fix toctree and add missing attribute * Add DonutSwin * Make conversion script work * Improve conversion script * Address comment * Fix bug * Fix another bug * Remove deprecated method from docs * Make Swin and Swinv2 untouched * Fix code examples * Fix processor * Update model_type to donut-swin * Add feature extractor tests, add token2json method, improve feature extractor * Fix failing tests, remove integration test * Add do_thumbnail for consistency * Improve code examples * Add code example for document parsing * Add DonutSwin to MODEL_NAMES_MAPPING * Add model to appropriate place in toctree * Update namespace to appropriate organization Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -13,14 +13,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from datasets import load_dataset
|
||||
from packaging import version
|
||||
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, to_2tuple, torch_device
|
||||
from transformers import DonutProcessor, TrOCRProcessor
|
||||
from transformers.testing_utils import (
|
||||
require_sentencepiece,
|
||||
require_torch,
|
||||
require_vision,
|
||||
slow,
|
||||
to_2tuple,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import cached_property, is_torch_available, is_vision_available
|
||||
|
||||
from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||
@@ -54,7 +62,7 @@ if is_vision_available():
|
||||
import PIL
|
||||
from PIL import Image
|
||||
|
||||
from transformers import TrOCRProcessor, ViTFeatureExtractor
|
||||
from transformers import ViTFeatureExtractor
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -654,8 +662,8 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_handwritten(self):
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(torch_device)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
|
||||
image = Image.open(ds[0]["file"]).convert("RGB")
|
||||
dataset = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
|
||||
image = Image.open(dataset[0]["file"]).convert("RGB")
|
||||
|
||||
processor = self.default_processor
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
@@ -679,8 +687,8 @@ class TrOCRModelIntegrationTest(unittest.TestCase):
|
||||
def test_inference_printed(self):
|
||||
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed").to(torch_device)
|
||||
|
||||
ds = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
|
||||
image = Image.open(ds[1]["file"]).convert("RGB")
|
||||
dataset = load_dataset("hf-internal-testing/fixtures_ocr", split="test")
|
||||
image = Image.open(dataset[1]["file"]).convert("RGB")
|
||||
|
||||
processor = self.default_processor
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
@@ -774,3 +782,197 @@ class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
|
||||
# should produce
|
||||
# ["a cat laying on top of a couch next to another cat"]
|
||||
self.assertEqual(preds, ["a cat laying on top of a couch next to another cat"])
|
||||
|
||||
|
||||
@require_vision
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
class DonutModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_docvqa(self):
|
||||
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
|
||||
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[0]["image"]
|
||||
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
decoder_input_ids = processor.tokenizer(
|
||||
"<s_docvqa>", add_special_tokens=False, return_tensors="pt"
|
||||
).input_ids.to(torch_device)
|
||||
|
||||
# step 1: single forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size([1, 1, 57532])
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([24.2731, -6.4522, 32.4130]).to(torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
# step 2: generation
|
||||
task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
|
||||
question = "When is the coffee break?"
|
||||
prompt = task_prompt.replace("{user_input}", question)
|
||||
decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=model.decoder.config.max_position_embeddings,
|
||||
early_stopping=True,
|
||||
pad_token_id=processor.tokenizer.pad_token_id,
|
||||
eos_token_id=processor.tokenizer.eos_token_id,
|
||||
use_cache=True,
|
||||
num_beams=1,
|
||||
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
|
||||
# verify generated sequence
|
||||
self.assertEqual(
|
||||
sequence, "<s_question> When is the coffee break?</s_question><s_answer> 11-14 to 11:39 a.m.</s_answer>"
|
||||
)
|
||||
|
||||
# verify scores
|
||||
self.assertEqual(len(outputs.scores), 11)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.scores[0][0, :3], torch.tensor([5.3153, -3.5276, 13.4781], device=torch_device), atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_cordv2(self):
|
||||
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
|
||||
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[2]["image"]
|
||||
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
decoder_input_ids = processor.tokenizer(
|
||||
"<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
|
||||
).input_ids.to(torch_device)
|
||||
|
||||
# step 1: single forward pass
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
# step 2: generation
|
||||
task_prompt = "<s_cord-v2>"
|
||||
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=model.decoder.config.max_position_embeddings,
|
||||
early_stopping=True,
|
||||
pad_token_id=processor.tokenizer.pad_token_id,
|
||||
eos_token_id=processor.tokenizer.eos_token_id,
|
||||
use_cache=True,
|
||||
num_beams=1,
|
||||
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
|
||||
# verify generated sequence
|
||||
# fmt: off
|
||||
expected_sequence = "<s_menu><s_nm> CINNAMON SUGAR</s_nm><s_unitprice> 17,000</s_unitprice><s_cnt> 1 x</s_cnt><s_price> 17,000</s_price></s_menu><s_sub_total><s_subtotal_price> 17,000</s_subtotal_price></s_sub_total><s_total><s_total_price> 17,000</s_total_price><s_cashprice> 20,000</s_cashprice><s_changeprice> 3,000</s_changeprice></s_total>" # noqa: E231
|
||||
# fmt: on
|
||||
self.assertEqual(sequence, expected_sequence)
|
||||
|
||||
# verify scores
|
||||
self.assertEqual(len(outputs.scores), 43)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.scores[0][0, :3], torch.tensor([-27.4344, -3.2686, -19.3524], device=torch_device), atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_inference_rvlcdip(self):
|
||||
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
|
||||
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
dataset = load_dataset("hf-internal-testing/example-documents", split="test")
|
||||
image = dataset[1]["image"]
|
||||
|
||||
pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
|
||||
|
||||
# step 1: single forward pass
|
||||
decoder_input_ids = processor.tokenizer(
|
||||
"<s_rvlcdip>", add_special_tokens=False, return_tensors="pt"
|
||||
).input_ids.to(torch_device)
|
||||
with torch.no_grad():
|
||||
outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
|
||||
logits = outputs.logits
|
||||
|
||||
# verify the logits
|
||||
expected_shape = torch.Size((1, 1, model.decoder.config.vocab_size))
|
||||
self.assertEqual(outputs.logits.shape, expected_shape)
|
||||
|
||||
expected_slice = torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device)
|
||||
self.assertTrue(torch.allclose(logits[0, 0, :3], expected_slice, atol=1e-4))
|
||||
|
||||
# step 2: generation
|
||||
task_prompt = "<s_rvlcdip>"
|
||||
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
||||
decoder_input_ids = decoder_input_ids.to(torch_device)
|
||||
|
||||
outputs = model.generate(
|
||||
pixel_values,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
max_length=model.decoder.config.max_position_embeddings,
|
||||
early_stopping=True,
|
||||
pad_token_id=processor.tokenizer.pad_token_id,
|
||||
eos_token_id=processor.tokenizer.eos_token_id,
|
||||
use_cache=True,
|
||||
num_beams=1,
|
||||
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
sequence = processor.batch_decode(outputs.sequences)[0]
|
||||
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
||||
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
||||
|
||||
# verify generated sequence
|
||||
self.assertEqual(sequence, "<s_class><advertisement/></s_class>")
|
||||
|
||||
# verify scores
|
||||
self.assertEqual(len(outputs.scores), 4)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
outputs.scores[0][0, :3], torch.tensor([-17.6490, -4.8381, -15.7577], device=torch_device), atol=1e-4
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user