Add LayoutLMForQuestionAnswering model (#18407)
* Add LayoutLMForQuestionAnswering model
* Fix output
* Remove TF TODOs
* Add test cases
* Add docs
* TF implementation
* Fix PT/TF equivalence
* Fix loss
* make fixup
* Fix up documentation code examples
* Fix up documentation examples + test them
* Remove LayoutLMForQuestionAnswering from the auto mapping
* Docstrings
* Add better docstrings
* Undo whitespace changes
* Update tokenizers in comments
* Fixup code and remove `from_pt=True`
* Fix tests
* Revert some unexpected docstring changes
* Fix tests by overriding _prepare_for_class

Co-authored-by: Ankur Goyal <ankur@impira.com>
This commit is contained in:
@@ -13,10 +13,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
from transformers import LayoutLMConfig, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -27,7 +28,11 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
LayoutLMForMaskedLM,
|
||||
LayoutLMForQuestionAnswering,
|
||||
LayoutLMForSequenceClassification,
|
||||
LayoutLMForTokenClassification,
|
||||
LayoutLMModel,
|
||||
@@ -181,6 +186,23 @@ class LayoutLMModelTester:
|
||||
result = model(input_ids, bbox, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_for_question_answering(
    self, config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Run `LayoutLMForQuestionAnswering` once and check the shapes of its span logits.

    `sequence_labels` is supplied as both `start_positions` and `end_positions`.
    `token_labels` and `choice_labels` are accepted but not used by this check.
    """
    qa_model = LayoutLMForQuestionAnswering(config=config)
    qa_model.to(torch_device)
    qa_model.eval()

    forward_kwargs = {
        "bbox": bbox,
        "attention_mask": input_mask,
        "token_type_ids": token_type_ids,
        "start_positions": sequence_labels,
        "end_positions": sequence_labels,
    }
    result = qa_model(input_ids, **forward_kwargs)

    # Start and end logits are both expected to have shape (batch_size, seq_length).
    expected_shape = (self.batch_size, self.seq_length)
    self.parent.assertEqual(result.start_logits.shape, expected_shape)
    self.parent.assertEqual(result.end_logits.shape, expected_shape)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -211,6 +233,7 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
LayoutLMForMaskedLM,
|
||||
LayoutLMForSequenceClassification,
|
||||
LayoutLMForTokenClassification,
|
||||
LayoutLMForQuestionAnswering,
|
||||
)
|
||||
if is_torch_available()
|
||||
else None
|
||||
@@ -246,6 +269,34 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_question_answering(self):
    """Run the model tester's question-answering check on the inputs it prepares."""
    prepared = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_for_question_answering(*prepared)
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
    """Return a deep copy of `inputs_dict`, optionally extended with all-zero label tensors.

    When `return_labels` is true, the labels added depend on `model_class`:

    - sequence classification models get `labels` of shape (batch_size,)
    - token classification and masked LM models get `labels` of shape (batch_size, seq_length)
    - `LayoutLMForQuestionAnswering` (matched by class name) gets `start_positions` and
      `end_positions`, each of shape (batch_size,)

    Any other model class gets the copied inputs without labels.
    """

    def _zeros(*dims):
        # Long-typed zero tensor placed on the test device.
        return torch.zeros(dims, dtype=torch.long, device=torch_device)

    prepared = copy.deepcopy(inputs_dict)
    if not return_labels:
        return prepared

    if model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
        prepared["labels"] = _zeros(self.model_tester.batch_size)
    elif model_class in [
        *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
        *get_values(MODEL_FOR_MASKED_LM_MAPPING),
    ]:
        prepared["labels"] = _zeros(self.model_tester.batch_size, self.model_tester.seq_length)
    elif model_class.__name__ == "LayoutLMForQuestionAnswering":
        # Matched by name rather than through an auto-mapping lookup.
        for key in ("start_positions", "end_positions"):
            prepared[key] = _zeros(self.model_tester.batch_size)

    return prepared
|
||||
|
||||
|
||||
def prepare_layoutlm_batch_inputs():
|
||||
# Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on:
|
||||
@@ -337,3 +388,18 @@ class LayoutLMModelIntegrationTest(unittest.TestCase):
|
||||
logits = outputs.logits
|
||||
expected_shape = torch.Size((2, 25, 13))
|
||||
self.assertEqual(logits.shape, expected_shape)
|
||||
|
||||
@slow
def test_forward_pass_question_answering(self):
    """Forward a prepared batch through the question-answering model and check logit shapes."""
    # load the base checkpoint into the question answering model and move it to the test device
    # NOTE(review): the QA head is presumably not part of this base checkpoint (i.e. newly
    # initialized) - only output shapes are checked here, not values
    model = LayoutLMForQuestionAnswering.from_pretrained("microsoft/layoutlm-base-uncased").to(torch_device)

    # the fifth value returned by the helper is not needed for this test
    input_ids, attention_mask, bbox, token_type_ids, _ = prepare_layoutlm_batch_inputs()

    # forward pass
    outputs = model(
        input_ids=input_ids,
        bbox=bbox,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )

    # start and end logits are both expected to have shape (2, 25)
    span_shape = torch.Size((2, 25))
    self.assertEqual(outputs.start_logits.shape, span_shape)
    self.assertEqual(outputs.end_logits.shape, span_shape)
|
||||
|
||||
Reference in New Issue
Block a user