Fix TF GPT2 test_onnx_runtime_optimize (#17874)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import GPT2Config, is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
from transformers.testing_utils import require_tf, require_tf2onnx, slow
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||
@@ -444,6 +444,32 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
||||
model = TFGPT2Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
# overwrite from common since ONNX runtime optimization doesn't work with tf.gather() when the argument
|
||||
# `batch_dims` > 0"
|
||||
@require_tf2onnx
|
||||
@slow
|
||||
def test_onnx_runtime_optimize(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
import onnxruntime
|
||||
import tf2onnx
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
# Skip these 2 classes which uses `tf.gather` with `batch_dims=1`
|
||||
if model_class in [TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel]:
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||
|
||||
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user