Fix TF GPT2 test_onnx_runtime_optimize (#17874)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import GPT2Config, is_tf_available
|
from transformers import GPT2Config, is_tf_available
|
||||||
from transformers.testing_utils import require_tf, slow
|
from transformers.testing_utils import require_tf, require_tf2onnx, slow
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
|
||||||
@@ -444,6 +444,32 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, unittest.TestC
|
|||||||
model = TFGPT2Model.from_pretrained(model_name)
|
model = TFGPT2Model.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
# overwrite from common since ONNX runtime optimization doesn't work with tf.gather() when the argument
|
||||||
|
# `batch_dims` > 0
|
||||||
|
@require_tf2onnx
|
||||||
|
@slow
|
||||||
|
def test_onnx_runtime_optimize(self):
|
||||||
|
if not self.test_onnx:
|
||||||
|
return
|
||||||
|
|
||||||
|
import onnxruntime
|
||||||
|
import tf2onnx
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
|
||||||
|
# Skip these 2 classes, which use `tf.gather` with `batch_dims=1`
|
||||||
|
if model_class in [TFGPT2ForSequenceClassification, TFGPT2DoubleHeadsModel]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = model_class(config)
|
||||||
|
model(model.dummy_inputs)
|
||||||
|
|
||||||
|
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||||
|
|
||||||
|
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user