add new model of MGP-STR (#21418)

* add new model of MGP-STR

* fix the check failings

* remove torch and numpy from mgp_tokenization

* remove unused import from modeling_mgp_str

* add test_processing_mgp_str

* rm test_processing_mgp_str.py

* add test_processing_mgp_str

* add test_processing_mgp_str

* add test_processing_mgp_str

* rm test_processing_mgp_str and add softmax outs to model

* rm test_processing_mgp_str and add softmax outs to model

* rewrite the code of mgp-str according to PR suggestions

* rewrite the code of mgp-str according to PR suggestions

* add new model of MGP-STR

* fix the check failings

* remove torch and numpy from mgp_tokenization

* remove unused import from modeling_mgp_str

* add test_processing_mgp_str

* rm test_processing_mgp_str.py

* add test_processing_mgp_str

* add test_processing_mgp_str

* add test_processing_mgp_str

* rm test_processing_mgp_str and add softmax outs to model

* rewrite the code of mgp-str according to PR suggestions

* rewrite the code of mgp-str according to PR suggestions

* remove representation_size from MGPSTRConfig

* reformat configuration_mgp_str.py

* format test_processor_mgp_str.py

* add test for tokenizer and complete model/processer test and model file

* rm Unnecessary tupple in modeling_mgp_str

* reduce hidden_size/layers/label_size in test_model

* add integration tests and change MGPSTR to Mgpstr

* add test for logit values

* reformat test model file

---------

Co-authored-by: yue kun <yuekun.wp@alibaba-inc.com>
This commit is contained in:
wangpeng
2023-03-13 18:11:31 +08:00
committed by GitHub
parent 32e3466d38
commit 102b5ff4a8
28 changed files with 1773 additions and 0 deletions

View File

View File

@@ -0,0 +1,269 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch MGP-STR model. """
import inspect
import unittest
import requests
from transformers import MgpstrConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor
if is_torch_available():
import torch
from torch import nn
from transformers import MgpstrForSceneTextRecognition
if is_vision_available():
from PIL import Image
from transformers import MgpstrProcessor
class MgpstrModelTester:
def __init__(
self,
parent,
is_training=False,
batch_size=13,
image_size=(32, 128),
patch_size=4,
num_channels=3,
max_token_length=27,
num_character_labels=38,
num_bpe_labels=99,
num_wordpiece_labels=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
mlp_ratio=4.0,
patch_embeds_hidden_size=257,
output_hidden_states=None,
):
self.parent = parent
self.is_training = is_training
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.max_token_length = max_token_length
self.num_character_labels = num_character_labels
self.num_bpe_labels = num_bpe_labels
self.num_wordpiece_labels = num_wordpiece_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.mlp_ratio = mlp_ratio
self.patch_embeds_hidden_size = patch_embeds_hidden_size
self.output_hidden_states = output_hidden_states
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
config = self.get_config()
return config, pixel_values
def get_config(self):
return MgpstrConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
max_token_length=self.max_token_length,
num_character_labels=self.num_character_labels,
num_bpe_labels=self.num_bpe_labels,
num_wordpiece_labels=self.num_wordpiece_labels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
mlp_ratio=self.mlp_ratio,
output_hidden_states=self.output_hidden_states,
)
def create_and_check_model(self, config, pixel_values):
model = MgpstrForSceneTextRecognition(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
generated_ids = model(pixel_values)
self.parent.assertEqual(
generated_ids[0][0].shape, (self.batch_size, self.max_token_length, self.num_character_labels)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class MgpstrModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (MgpstrForSceneTextRecognition,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_attention_outputs = False
def setUp(self):
self.model_tester = MgpstrModelTester(self)
self.config_tester = ConfigTester(self, config_class=MgpstrConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="MgpstrModel does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
@unittest.skip(reason="MgpstrModel does not support feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_gradient_checkpointing_backward_compatibility(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class.supports_gradient_checkpointing:
continue
config.gradient_checkpointing = True
model = model_class(config)
self.assertTrue(model.is_gradient_checkpointing)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.patch_embeds_hidden_size, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# override as the `logit_scale` parameter initilization is different for MgpstrModel
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if isinstance(param, (nn.Linear, nn.Conv2d, nn.LayerNorm)):
if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
# We will verify our results on an image from the IIIT-5k dataset
def prepare_img():
url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png"
im = Image.open(requests.get(url, stream=True).raw).convert("RGB")
return im
@require_vision
@require_torch
class MgpstrModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model_name = "alibaba-damo/mgp-str-base"
model = MgpstrForSceneTextRecognition.from_pretrained(model_name).to(torch_device)
processor = MgpstrProcessor.from_pretrained(model_name)
image = prepare_img()
inputs = processor(images=image, return_tensors="pt").pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(inputs)
# verify the logits
self.assertEqual(outputs.logits[0].shape, torch.Size((1, 27, 38)))
out_strs = processor.batch_decode(outputs.logits)
expected_text = "ticket"
self.assertEqual(out_strs["generated_text"][0], expected_text)
expected_slice = torch.tensor(
[[[-39.7358, -44.8562, -36.6253], [-62.3605, -64.5908, -59.0069], [-74.6127, -68.9724, -71.7150]]],
device=torch_device,
)
self.assertTrue(torch.allclose(outputs.logits[0][:, 1:4, 1:4], expected_slice, atol=1e-4))

View File

@@ -0,0 +1,211 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the MgpstrProcessor. """
import json
import os
import shutil
import tempfile
import unittest
import numpy as np
import pytest
from transformers import MgpstrTokenizer
from transformers.models.mgp_str.tokenization_mgp_str import VOCAB_FILES_NAMES
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import MgpstrProcessor, ViTImageProcessor
@require_torch
@require_vision
class MgpstrProcessorTest(unittest.TestCase):
image_processing_class = ViTImageProcessor if is_vision_available() else None
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def setUp(self):
self.image_size = (3, 32, 128)
self.tmpdirname = tempfile.mkdtemp()
# fmt: off
vocab = ['[GO]', '[s]', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
# fmt: on
vocab_tokens = dict(zip(vocab, range(len(vocab))))
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
image_processor_map = {
"do_normalize": False,
"do_resize": True,
"feature_extractor_type": "ViTFeatureExtractor",
"resample": 3,
"size": {"height": 32, "width": 128},
}
self.image_processor_file = os.path.join(self.tmpdirname, IMAGE_PROCESSOR_NAME)
with open(self.image_processor_file, "w", encoding="utf-8") as fp:
json.dump(image_processor_map, fp)
def get_tokenizer(self, **kwargs):
return MgpstrTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_image_processor(self, **kwargs):
return ViTImageProcessor.from_pretrained(self.tmpdirname, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def prepare_image_inputs(self):
"""This function prepares a list of PIL images."""
image_input = np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)
image_input = Image.fromarray(np.moveaxis(image_input, 0, -1))
return image_input
def test_save_load_pretrained_default(self):
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
processor.save_pretrained(self.tmpdirname)
processor = MgpstrProcessor.from_pretrained(self.tmpdirname, use_fast=False)
self.assertEqual(processor.char_tokenizer.get_vocab(), tokenizer.get_vocab())
self.assertIsInstance(processor.char_tokenizer, MgpstrTokenizer)
self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string())
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
def test_save_load_pretrained_additional_features(self):
tokenizer = self.get_tokenizer()
image_processor = self.get_image_processor()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)")
image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)
processor = MgpstrProcessor.from_pretrained(
self.tmpdirname, bos_token="(BOS)", eos_token="(EOS)", do_normalize=False, padding_value=1.0
)
self.assertEqual(processor.char_tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.char_tokenizer, MgpstrTokenizer)
self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
self.assertIsInstance(processor.image_processor, ViTImageProcessor)
def test_image_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
image_input = self.prepare_image_inputs()
input_image_proc = image_processor(image_input, return_tensors="np")
input_processor = processor(images=image_input, return_tensors="np")
for key in input_image_proc.keys():
self.assertAlmostEqual(input_image_proc[key].sum(), input_processor[key].sum(), delta=1e-2)
def test_tokenizer(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = "test"
encoded_processor = processor(text=input_str)
encoded_tok = tokenizer(input_str)
for key in encoded_tok.keys():
self.assertListEqual(encoded_tok[key], encoded_processor[key])
def test_processor(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = "test"
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), ["pixel_values", "labels"])
# test if it raises when no input is passed
with pytest.raises(ValueError):
processor()
def test_tokenizer_decode(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
predicted_ids = [[1, 4, 5, 8, 1, 0, 8], [3, 4, 3, 1, 1, 8, 9], [3, 4, 3, 1, 1, 8, 9]]
decoded_processor = processor.char_decode(predicted_ids)
decoded_tok = tokenizer.batch_decode(predicted_ids)
decode_strs = [seq.replace(" ", "") for seq in decoded_tok]
self.assertListEqual(decode_strs, decoded_processor)
def test_model_input_names(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
input_str = None
image_input = self.prepare_image_inputs()
inputs = processor(text=input_str, images=image_input)
self.assertListEqual(list(inputs.keys()), processor.model_input_names)
def test_processor_batch_decode(self):
image_processor = self.get_image_processor()
tokenizer = self.get_tokenizer()
processor = MgpstrProcessor(tokenizer=tokenizer, image_processor=image_processor)
char_input = torch.randn(1, 27, 38)
bpe_input = torch.randn(1, 27, 50257)
wp_input = torch.randn(1, 27, 30522)
results = processor.batch_decode([char_input, bpe_input, wp_input])
self.assertListEqual(list(results.keys()), ["generated_text", "scores", "char_preds", "bpe_preds", "wp_preds"])

View File

@@ -0,0 +1,96 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import unittest
from transformers import MgpstrTokenizer
from transformers.models.mgp_str.tokenization_mgp_str import VOCAB_FILES_NAMES
from transformers.testing_utils import require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
class MgpstrTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MgpstrTokenizer
test_rust_tokenizer = False
from_pretrained_kwargs = {}
test_seq2seq = False
def setUp(self):
super().setUp()
# fmt: off
vocab = ['[GO]', '[s]', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
# fmt: on
vocab_tokens = dict(zip(vocab, range(len(vocab))))
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding="utf-8") as fp:
fp.write(json.dumps(vocab_tokens) + "\n")
def get_tokenizer(self, **kwargs):
return MgpstrTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self, tokenizer):
input_text = "tester"
output_text = "tester"
return input_text, output_text
@unittest.skip("MGP-STR always lower cases letters.")
def test_added_tokens_do_lower_case(self):
pass
def test_add_special_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
special_token = "[SPECIAL_TOKEN]"
tokenizer.add_special_tokens({"cls_token": special_token})
encoded_special_token = tokenizer.encode([special_token], add_special_tokens=False)
self.assertEqual(len(encoded_special_token), 1)
decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True)
self.assertTrue(special_token not in decoded)
def test_internal_consistency(self):
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
input_text, output_text = self.get_input_output_texts(tokenizer)
tokens = tokenizer.tokenize(input_text)
ids = tokenizer.convert_tokens_to_ids(tokens)
ids_2 = tokenizer.encode(input_text, add_special_tokens=False)
self.assertListEqual(ids, ids_2)
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
self.assertNotEqual(len(tokens_2), 0)
text_2 = tokenizer.decode(ids)
self.assertIsInstance(text_2, str)
self.assertEqual(text_2.replace(" ", ""), output_text)
@unittest.skip("MGP-STR tokenizer only handles one sequence.")
def test_maximum_encoding_length_pair_input(self):
pass
@unittest.skip("inputs cannot be pretokenized in MgpstrTokenizer")
def test_pretokenized_inputs(self):
pass