Check TF ops for ONNX compliance (#10025)

* Add check-ops script

* Finish to implement check_tf_ops and start the test

* Make the test mandatory only for BERT

* Update tf_ops folder

* Remove useless classes

* Add the ONNX test for GPT2 and BART

* Add a onnxruntime slow test + better opset flexibility

* Fix test + apply style

* fix tests

* Switch min opset from 12 to 10

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix GPT2

* Remove extra shape_list usage

* Fix GPT2

* Address Morgan's comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Julien Plu
2021-02-15 13:55:10 +01:00
committed by GitHub
parent 93bd2f7099
commit c8d3fa0dfd
33 changed files with 468 additions and 17 deletions

View File

@@ -16,6 +16,7 @@
import copy
import inspect
import json
import os
import random
import tempfile
@@ -24,7 +25,7 @@ from importlib import import_module
from typing import List, Tuple
from transformers import is_tf_available
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
if is_tf_available():
@@ -201,6 +202,67 @@ class TFModelTesterMixin:
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
self.assertTrue(os.path.exists(saved_model_dir))
def test_onnx_compliancy(self):
if not self.test_onnx:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
INTERNAL_OPS = [
"Assert",
"AssignVariableOp",
"EmptyTensorList",
"ReadVariableOp",
"ResourceGather",
"TruncatedNormal",
"VarHandleOp",
"VarIsInitializedOp",
]
onnx_ops = []
with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
onnx_opsets = json.load(f)["opsets"]
for i in range(1, self.onnx_min_opset + 1):
onnx_ops.extend(onnx_opsets[str(i)])
for model_class in self.all_model_classes:
model_op_names = set()
with tf.Graph().as_default() as g:
model = model_class(config)
model(model.dummy_inputs)
for op in g.get_operations():
model_op_names.add(op.node_def.op)
model_op_names = sorted(model_op_names)
incompatible_ops = []
for op in model_op_names:
if op not in onnx_ops and op not in INTERNAL_OPS:
incompatible_ops.append(op)
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
@require_onnx
@slow
def test_onnx_runtime_optimize(self):
if not self.test_onnx:
return
import keras2onnx
import onnxruntime
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model(model.dummy_inputs)
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
onnxruntime.InferenceSession(onnx_model.SerializeToString())
@slow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()