diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0a55f1d11c..145157c54d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -658,6 +658,7 @@ class ModelTesterMixin: attention_mask = inputs["attention_mask"] decoder_input_ids = inputs["decoder_input_ids"] decoder_attention_mask = inputs["decoder_attention_mask"] + model(main_input, attention_mask, decoder_input_ids, decoder_attention_mask) traced_model = torch.jit.trace( model, (main_input, attention_mask, decoder_input_ids, decoder_attention_mask) ) @@ -665,11 +666,13 @@ class ModelTesterMixin: input_ids = inputs["input_ids"] bbox = inputs["bbox"] image = inputs["image"].tensor + model(input_ids, bbox, image) traced_model = torch.jit.trace( model, (input_ids, bbox, image), check_trace=False ) # when traced model is checked, an error is produced due to name mangling else: main_input = inputs[main_input_name] + model(main_input) traced_model = torch.jit.trace(model, main_input) except RuntimeError: self.fail("Couldn't trace module.")