Fixes torch jit tracing for LayoutLMv2 model (re-open) (#18313)
* Fixes torch jit tracing for LayoutLMv2 model. Pytorch seems to reuse memory for input_shape which caused a mismatch in shapes later in the forward pass. * Fixed code quality * avoid unneeded allocation of vector for shape
This commit is contained in:
@@ -260,7 +260,7 @@ class LayoutLMv2ModelTester:
|
||||
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
test_pruning = False
|
||||
test_torchscript = False
|
||||
test_torchscript = True
|
||||
test_mismatched_shapes = False
|
||||
|
||||
all_model_classes = (
|
||||
|
||||
@@ -648,6 +648,13 @@ class ModelTesterMixin:
|
||||
traced_model = torch.jit.trace(
|
||||
model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask)
|
||||
)
|
||||
elif "bbox" in inputs and "image" in inputs: # LayoutLMv2 requires additional inputs
|
||||
input_ids = inputs["input_ids"]
|
||||
bbox = inputs["bbox"]
|
||||
image = inputs["image"].tensor
|
||||
traced_model = torch.jit.trace(
|
||||
model, (input_ids, bbox, image), check_trace=False
|
||||
) # when traced model is checked, an error is produced due to name mangling
|
||||
else:
|
||||
main_input = inputs[main_input_name]
|
||||
traced_model = torch.jit.trace(model, main_input)
|
||||
|
||||
Reference in New Issue
Block a user