From 401fcca6c561d61db6ce25d9b1cebb75325a034f Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 27 Jun 2022 09:27:30 +0200 Subject: [PATCH] Fix TF GPT2 test_onnx_runtime_optimize (#17874) Co-authored-by: ydshieh --- tests/models/gpt2/test_modeling_tf_gpt2.py | 28 +++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py index efa3f0ac1c..b72c3c08eb 100644 --- a/tests/models/gpt2/test_modeling_tf_gpt2.py +++ b/tests/models/gpt2/test_modeling_tf_gpt2.py @@ -16,7 +16,7 @@ import unittest from transformers import GPT2Config, is_tf_available -from transformers.testing_utils import require_tf, slow +from transformers.testing_utils import require_tf, require_tf2onnx, slow from ...test_configuration_common import ConfigTester from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask @@ -444,6 +444,32 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC model = TFGPT2Model.from_pretrained(model_name) self.assertIsNotNone(model) + # overwrite from common since ONNX runtime optimization doesn't work with tf.gather() when the argument + # `batch_dims` > 0" + @require_tf2onnx + @slow + def test_onnx_runtime_optimize(self): + if not self.test_onnx: + return + + import onnxruntime + import tf2onnx + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + + # Skip these 2 classes which uses `tf.gather` with `batch_dims=1` + if model_class in [TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel]: + continue + + model = model_class(config) + model(model.dummy_inputs) + + onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset) + + onnxruntime.InferenceSession(onnx_model_proto.SerializeToString()) + @require_tf class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):