From 3a1aeea3c5d9442dc891fcd4180d4a089dda567a Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 21 Oct 2022 16:23:13 +0200 Subject: [PATCH] Fix CTRL `test_torchscrip_xxx` CI by updating `_create_and_check_torchscript` (#19786) * Run inputs before trace * Run inputs before trace Co-authored-by: ydshieh --- tests/test_modeling_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0a55f1d11c..145157c54d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -658,6 +658,7 @@ class ModelTesterMixin: attention_mask = inputs["attention_mask"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] + model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask) traced_model = torch.jit.trace( model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask) ) @@ -665,11 +666,13 @@ class ModelTesterMixin: input_ids = inputs["input_ids"] bbox = inputs["bbox"] image = inputs["image"].tensor + model(input_ids, bbox, image) traced_model = torch.jit.trace( model, (input_ids, bbox, image), check_trace=False ) # when traced model is checked, an error is produced due to name mangling else: main_input = inputs[main_input_name] + model(main_input) traced_model = torch.jit.trace(model, main_input) except RuntimeError: self.fail("Couldn't trace module.")