Fixes torch jit tracing for LayoutLMv2 model (re-open) (#18313)

* Fixes torch jit tracing for LayoutLMv2 model.
Pytorch seems to reuse memory for input_shape which caused a mismatch in shapes later in the forward pass.

* Fixed code quality

* avoid unneeded allocation of vector for shape
This commit is contained in:
Mikkel Denker
2022-07-27 12:38:40 +02:00
committed by GitHub
parent 1d71ad8905
commit 70e7d1d656
3 changed files with 21 additions and 11 deletions

View File

@@ -260,7 +260,7 @@ class LayoutLMv2ModelTester:
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_torchscript = True
test_mismatched_shapes = False
all_model_classes = (

View File

@@ -648,6 +648,13 @@ class ModelTesterMixin:
traced_model = torch.jit.trace(
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
)
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
image = inputs["image"].tensor
traced_model = torch.jit.trace(
model, (input_ids, bbox, image), check_trace=False
) # when traced model is checked, an error is produced due to name mangling
else:
main_input = inputs[main_input_name]
traced_model = torch.jit.trace(model, main_input)