update from keras2onnx to tf2onnx (#15162)
This commit is contained in:
8
setup.py
8
setup.py
@@ -114,7 +114,6 @@ _deps = [
|
|||||||
"jax>=0.2.8",
|
"jax>=0.2.8",
|
||||||
"jaxlib>=0.1.65",
|
"jaxlib>=0.1.65",
|
||||||
"jieba",
|
"jieba",
|
||||||
"keras2onnx",
|
|
||||||
"nltk",
|
"nltk",
|
||||||
"numpy>=1.17",
|
"numpy>=1.17",
|
||||||
"onnxconverter-common",
|
"onnxconverter-common",
|
||||||
@@ -147,6 +146,7 @@ _deps = [
|
|||||||
"starlette",
|
"starlette",
|
||||||
"tensorflow-cpu>=2.3",
|
"tensorflow-cpu>=2.3",
|
||||||
"tensorflow>=2.3",
|
"tensorflow>=2.3",
|
||||||
|
"tf2onnx",
|
||||||
"timeout-decorator",
|
"timeout-decorator",
|
||||||
"timm",
|
"timm",
|
||||||
"tokenizers>=0.10.1",
|
"tokenizers>=0.10.1",
|
||||||
@@ -229,8 +229,8 @@ extras = {}
|
|||||||
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
|
extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic")
|
||||||
extras["sklearn"] = deps_list("scikit-learn")
|
extras["sklearn"] = deps_list("scikit-learn")
|
||||||
|
|
||||||
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx")
|
extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx")
|
||||||
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx")
|
extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx")
|
||||||
|
|
||||||
extras["torch"] = deps_list("torch")
|
extras["torch"] = deps_list("torch")
|
||||||
|
|
||||||
@@ -243,7 +243,7 @@ else:
|
|||||||
|
|
||||||
extras["tokenizers"] = deps_list("tokenizers")
|
extras["tokenizers"] = deps_list("tokenizers")
|
||||||
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
|
||||||
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"]
|
extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"]
|
||||||
extras["modelcreation"] = deps_list("cookiecutter")
|
extras["modelcreation"] = deps_list("cookiecutter")
|
||||||
|
|
||||||
extras["sagemaker"] = deps_list("sagemaker")
|
extras["sagemaker"] = deps_list("sagemaker")
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
|
|||||||
|
|
||||||
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
||||||
"""
|
"""
|
||||||
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR
|
Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nlp: The pipeline to be exported
|
nlp: The pipeline to be exported
|
||||||
@@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
|
|||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from keras2onnx import __version__ as k2ov
|
from tf2onnx import __version__ as t2ov
|
||||||
from keras2onnx import convert_keras, save_model
|
from tf2onnx import convert_keras, save_model
|
||||||
|
|
||||||
print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}")
|
print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
|
||||||
|
|
||||||
# Build
|
# Build
|
||||||
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
|
input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ deps = {
|
|||||||
"jax": "jax>=0.2.8",
|
"jax": "jax>=0.2.8",
|
||||||
"jaxlib": "jaxlib>=0.1.65",
|
"jaxlib": "jaxlib>=0.1.65",
|
||||||
"jieba": "jieba",
|
"jieba": "jieba",
|
||||||
"keras2onnx": "keras2onnx",
|
|
||||||
"nltk": "nltk",
|
"nltk": "nltk",
|
||||||
"numpy": "numpy>=1.17",
|
"numpy": "numpy>=1.17",
|
||||||
"onnxconverter-common": "onnxconverter-common",
|
"onnxconverter-common": "onnxconverter-common",
|
||||||
@@ -57,6 +56,7 @@ deps = {
|
|||||||
"starlette": "starlette",
|
"starlette": "starlette",
|
||||||
"tensorflow-cpu": "tensorflow-cpu>=2.3",
|
"tensorflow-cpu": "tensorflow-cpu>=2.3",
|
||||||
"tensorflow": "tensorflow>=2.3",
|
"tensorflow": "tensorflow>=2.3",
|
||||||
|
"tf2onnx": "tf2onnx",
|
||||||
"timeout-decorator": "timeout-decorator",
|
"timeout-decorator": "timeout-decorator",
|
||||||
"timm": "timm",
|
"timm": "timm",
|
||||||
"tokenizers": "tokenizers>=0.10.1",
|
"tokenizers": "tokenizers>=0.10.1",
|
||||||
|
|||||||
@@ -175,12 +175,12 @@ except importlib_metadata.PackageNotFoundError:
|
|||||||
_sympy_available = False
|
_sympy_available = False
|
||||||
|
|
||||||
|
|
||||||
_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
|
_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None
|
||||||
try:
|
try:
|
||||||
_keras2onnx_version = importlib_metadata.version("keras2onnx")
|
_tf2onnx_version = importlib_metadata.version("tf2onnx")
|
||||||
logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
|
logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}")
|
||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
_keras2onnx_available = False
|
_tf2onnx_available = False
|
||||||
|
|
||||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||||
try:
|
try:
|
||||||
@@ -429,8 +429,8 @@ def is_coloredlogs_available():
|
|||||||
return _coloredlogs_available
|
return _coloredlogs_available
|
||||||
|
|
||||||
|
|
||||||
def is_keras2onnx_available():
|
def is_tf2onnx_available():
|
||||||
return _keras2onnx_available
|
return _tf2onnx_available
|
||||||
|
|
||||||
|
|
||||||
def is_onnx_available():
|
def is_onnx_available():
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from .file_utils import (
|
|||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_ftfy_available,
|
is_ftfy_available,
|
||||||
is_keras2onnx_available,
|
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_onnx_available,
|
is_onnx_available,
|
||||||
is_pandas_available,
|
is_pandas_available,
|
||||||
@@ -49,6 +48,7 @@ from .file_utils import (
|
|||||||
is_soundfile_availble,
|
is_soundfile_availble,
|
||||||
is_spacy_available,
|
is_spacy_available,
|
||||||
is_tensorflow_probability_available,
|
is_tensorflow_probability_available,
|
||||||
|
is_tf2onnx_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_timm_available,
|
is_timm_available,
|
||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
@@ -246,9 +246,9 @@ def require_rjieba(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def require_keras2onnx(test_case):
|
def require_tf2onnx(test_case):
|
||||||
if not is_keras2onnx_available():
|
if not is_tf2onnx_available():
|
||||||
return unittest.skip("test requires keras2onnx")(test_case)
|
return unittest.skip("test requires tf2onnx")(test_case)
|
||||||
else:
|
else:
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|||||||
@@ -36,8 +36,8 @@ from transformers.testing_utils import (
|
|||||||
_tf_gpu_memory_limit,
|
_tf_gpu_memory_limit,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_keras2onnx,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
|
require_tf2onnx,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
@@ -254,14 +254,14 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
||||||
|
|
||||||
@require_keras2onnx
|
@require_tf2onnx
|
||||||
@slow
|
@slow
|
||||||
def test_onnx_runtime_optimize(self):
|
def test_onnx_runtime_optimize(self):
|
||||||
if not self.test_onnx:
|
if not self.test_onnx:
|
||||||
return
|
return
|
||||||
|
|
||||||
import keras2onnx
|
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
|
import tf2onnx
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -269,9 +269,9 @@ class TFModelTesterMixin:
|
|||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model(model.dummy_inputs)
|
model(model.dummy_inputs)
|
||||||
|
|
||||||
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
|
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||||
|
|
||||||
onnxruntime.InferenceSession(onnx_model.SerializeToString())
|
onnxruntime.InferenceSession(onnx_model_proto.SerializeToString())
|
||||||
|
|
||||||
def test_keras_save_load(self):
|
def test_keras_save_load(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user