Fixes torch jit tracing for LayoutLMv2 model (re-open) (#18313)

* Fixes torch jit tracing for LayoutLMv2 model.
Pytorch seems to reuse memory for input_shape which caused a mismatch in shapes later in the forward pass.

* Fixed code quality

* avoid unneeded allocation of vector for shape
This commit is contained in:
Mikkel Denker
2022-07-27 12:38:40 +02:00
committed by GitHub
parent 1d71ad8905
commit 70e7d1d656
3 changed files with 21 additions and 11 deletions

View File

@@ -260,7 +260,7 @@ class LayoutLMv2ModelTester:
class LayoutLMv2ModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_torchscript = True
test_mismatched_shapes = False
all_model_classes = (