update from keras2onnx to tf2onnx (#15162)
This commit is contained in:
@@ -36,8 +36,8 @@ from transformers.testing_utils import (
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_keras2onnx,
|
||||
require_tf,
|
||||
require_tf2onnx,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
@@ -254,14 +254,14 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
||||
|
||||
@require_keras2onnx
|
||||
@require_tf2onnx
|
||||
@slow
|
||||
def test_onnx_runtime_optimize(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
import keras2onnx
|
||||
import onnxruntime
|
||||
import tf2onnx
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -269,9 +269,9 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
|
||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||
|
||||
onnxruntime.InferenceSession(onnx_model.SerializeToString())
|
||||
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
|
||||
|
||||
def test_keras_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user