From ebc4edfe7a4dfb820bf7f4faa02fb6de479b6e36 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 14 Jan 2022 17:35:39 +0000 Subject: [PATCH] update from keras2onnx to tf2onnx (#15162) --- setup.py | 8 ++++---- src/transformers/convert_graph_to_onnx.py | 8 ++++---- src/transformers/dependency_versions_table.py | 2 +- src/transformers/file_utils.py | 12 ++++++------ src/transformers/testing_utils.py | 8 ++++---- tests/test_modeling_tf_common.py | 10 +++++----- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index 015b7a8042..22bd308ae9 100644 --- a/setup.py +++ b/setup.py @@ -114,7 +114,6 @@ _deps = [ "jax>=0.2.8", "jaxlib>=0.1.65", "jieba", - "keras2onnx", "nltk", "numpy>=1.17", "onnxconverter-common", @@ -147,6 +146,7 @@ _deps = [ "starlette", "tensorflow-cpu>=2.3", "tensorflow>=2.3", + "tf2onnx", "timeout-decorator", "timm", "tokenizers>=0.10.1", @@ -229,8 +229,8 @@ extras = {} extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic") extras["sklearn"] = deps_list("scikit-learn") -extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "keras2onnx") -extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "keras2onnx") +extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx") +extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx") extras["torch"] = deps_list("torch") @@ -243,7 +243,7 @@ else: extras["tokenizers"] = deps_list("tokenizers") extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools") -extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"] +extras["onnx"] = deps_list("onnxconverter-common", "tf2onnx") + extras["onnxruntime"] extras["modelcreation"] = deps_list("cookiecutter") extras["sagemaker"] = deps_list("sagemaker") diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py index 5ca722b8f0..e58587837e 100644 --- a/src/transformers/convert_graph_to_onnx.py +++ b/src/transformers/convert_graph_to_onnx.py @@ -294,7 +294,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): """ - Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR + Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR) Args: nlp: The pipeline to be exported @@ -312,10 +312,10 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path): try: import tensorflow as tf - from keras2onnx import __version__ as k2ov - from keras2onnx import convert_keras, save_model + from tf2onnx import __version__ as t2ov + from tf2onnx import convert_keras, save_model - print(f"Using framework TensorFlow: {tf.version.VERSION}, keras2onnx: {k2ov}") + print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}") # Build input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf") diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index ee8b22b30c..28662db09e 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -24,7 +24,6 @@ deps = { "jax": "jax>=0.2.8", "jaxlib": "jaxlib>=0.1.65", "jieba": "jieba", - "keras2onnx": "keras2onnx", "nltk": "nltk", "numpy": "numpy>=1.17", "onnxconverter-common": "onnxconverter-common", @@ -57,6 +56,7 @@ deps = { "starlette": "starlette", "tensorflow-cpu": "tensorflow-cpu>=2.3", "tensorflow": "tensorflow>=2.3", + "tf2onnx": "tf2onnx", "timeout-decorator": "timeout-decorator", "timm": "timm", "tokenizers": "tokenizers>=0.10.1", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 408df727ac..f8599796cb 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -175,12 +175,12 @@ except importlib_metadata.PackageNotFoundError: _sympy_available = False -_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None +_tf2onnx_available = importlib.util.find_spec("tf2onnx") is not None try: - _keras2onnx_version = importlib_metadata.version("keras2onnx") - logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}") + _tf2onnx_version = importlib_metadata.version("tf2onnx") + logger.debug(f"Successfully imported tf2onnx version {_tf2onnx_version}") except importlib_metadata.PackageNotFoundError: - _keras2onnx_available = False + _tf2onnx_available = False _onnx_available = importlib.util.find_spec("onnxruntime") is not None try: @@ -429,8 +429,8 @@ def is_coloredlogs_available(): return _coloredlogs_available -def is_keras2onnx_available(): - return _keras2onnx_available +def is_tf2onnx_available(): + return _tf2onnx_available def is_onnx_available(): diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index a0654bd3f4..44bbbd7b99 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -35,7 +35,6 @@ from .file_utils import ( is_faiss_available, is_flax_available, is_ftfy_available, - is_keras2onnx_available, is_librosa_available, is_onnx_available, is_pandas_available, @@ -49,6 +48,7 @@ from .file_utils import ( is_soundfile_availble, is_spacy_available, is_tensorflow_probability_available, + is_tf2onnx_available, is_tf_available, is_timm_available, is_tokenizers_available, @@ -246,9 +246,9 @@ def require_rjieba(test_case): return test_case -def require_keras2onnx(test_case): - if not is_keras2onnx_available(): - return unittest.skip("test requires keras2onnx")(test_case) +def require_tf2onnx(test_case): + if not is_tf2onnx_available(): + return unittest.skip("test requires tf2onnx")(test_case) else: return test_case diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 57c97455ae..2c7e8fb103 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -36,8 +36,8 @@ from transformers.testing_utils import ( _tf_gpu_memory_limit, is_pt_tf_cross_test, is_staging_test, - require_keras2onnx, require_tf, + require_tf2onnx, slow, ) from transformers.utils import logging @@ -254,14 +254,14 @@ class TFModelTesterMixin: self.assertEqual(len(incompatible_ops), 0, incompatible_ops) - @require_keras2onnx + @require_tf2onnx @slow def test_onnx_runtime_optimize(self): if not self.test_onnx: return - import keras2onnx import onnxruntime + import tf2onnx config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -269,9 +269,9 @@ class TFModelTesterMixin: model = model_class(config) model(model.dummy_inputs) - onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset) + onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset) - onnxruntime.InferenceSession(onnx_model.SerializeToString()) + onnxruntime.InferenceSession(onnx_model_proto.SerializeToString()) def test_keras_save_load(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()