update from keras2onnx to tf2onnx (#15162)

This commit is contained in:
Joao Gante
2022-01-14 17:35:39 +00:00
committed by GitHub
parent 1b730c3d11
commit ebc4edfe7a
6 changed files with 24 additions and 24 deletions

View File

@@ -36,8 +36,8 @@ from transformers.testing_utils import (
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
require_keras2onnx,
require_tf,
require_tf2onnx,
slow,
)
from transformers.utils import logging
@@ -254,14 +254,14 @@ class TFModelTesterMixin:
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
@require_keras2onnx
@require_tf2onnx
@slow
def test_onnx_runtime_optimize(self):
if not self.test_onnx:
return
import keras2onnx
import onnxruntime
import tf2onnx
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -269,9 +269,9 @@ class TFModelTesterMixin:
model = model_class(config)
model(model.dummy_inputs)
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
onnxruntime.InferenceSession(onnx_model.SerializeToString())
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
def test_keras_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()