Check TF ops for ONNX compliance (#10025)
* Add check-ops script * Finish to implement check_tf_ops and start the test * Make the test mandatory only for BERT * Update tf_ops folder * Remove useless classes * Add the ONNX test for GPT2 and BART * Add a onnxruntime slow test + better opset flexibility * Fix test + apply style * fix tests * Switch min opset from 12 to 10 * Update src/transformers/file_utils.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Fix GPT2 * Remove extra shape_list usage * Fix GPT2 * Address Morgan's comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -241,6 +241,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFAlbertModelTester(self)
|
||||
|
||||
@@ -178,6 +178,8 @@ class TFBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = True
|
||||
onnx_min_opset = 10
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBartModelTester(self)
|
||||
|
||||
@@ -274,6 +274,8 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = True
|
||||
onnx_min_opset = 10
|
||||
|
||||
# special case for ForPreTraining model
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
|
||||
@@ -177,6 +177,7 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotModelTester(self)
|
||||
|
||||
@@ -179,6 +179,7 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFBlenderbotSmallForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFBlenderbotSmallModelTester(self)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
@@ -24,7 +25,7 @@ from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_tf, slow
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, is_pt_tf_cross_test, require_onnx, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@@ -201,6 +202,67 @@ class TFModelTesterMixin:
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
self.assertTrue(os.path.exists(saved_model_dir))
|
||||
|
||||
def test_onnx_compliancy(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
INTERNAL_OPS = [
|
||||
"Assert",
|
||||
"AssignVariableOp",
|
||||
"EmptyTensorList",
|
||||
"ReadVariableOp",
|
||||
"ResourceGather",
|
||||
"TruncatedNormal",
|
||||
"VarHandleOp",
|
||||
"VarIsInitializedOp",
|
||||
]
|
||||
onnx_ops = []
|
||||
|
||||
with open(os.path.join(".", "utils", "tf_ops", "onnx.json")) as f:
|
||||
onnx_opsets = json.load(f)["opsets"]
|
||||
|
||||
for i in range(1, self.onnx_min_opset + 1):
|
||||
onnx_ops.extend(onnx_opsets[str(i)])
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model_op_names = set()
|
||||
|
||||
with tf.Graph().as_default() as g:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
for op in g.get_operations():
|
||||
model_op_names.add(op.node_def.op)
|
||||
|
||||
model_op_names = sorted(model_op_names)
|
||||
incompatible_ops = []
|
||||
|
||||
for op in model_op_names:
|
||||
if op not in onnx_ops and op not in INTERNAL_OPS:
|
||||
incompatible_ops.append(op)
|
||||
|
||||
self.assertEqual(len(incompatible_ops), 0, incompatible_ops)
|
||||
|
||||
@require_onnx
|
||||
@slow
|
||||
def test_onnx_runtime_optimize(self):
|
||||
if not self.test_onnx:
|
||||
return
|
||||
|
||||
import keras2onnx
|
||||
import onnxruntime
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
|
||||
onnx_model = keras2onnx.convert_keras(model, model.name, target_opset=self.onnx_min_opset)
|
||||
|
||||
onnxruntime.InferenceSession(onnx_model.SerializeToString())
|
||||
|
||||
@slow
|
||||
def test_saved_model_creation_extended(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -239,6 +239,7 @@ class TFConvBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFConvBertModelTester(self)
|
||||
|
||||
@@ -174,6 +174,7 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFCTRLModel, TFCTRLLMHeadModel, TFCTRLForSequenceClassification) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFCTRLLMHeadModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFCTRLModelTester(self)
|
||||
|
||||
@@ -184,6 +184,7 @@ class TFDistilBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else None
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFDistilBertModelTester(self)
|
||||
|
||||
@@ -188,6 +188,7 @@ class TFDPRModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
test_missing_keys = False
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFDPRModelTester(self)
|
||||
|
||||
@@ -206,6 +206,7 @@ class TFElectraModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFElectraModelTester(self)
|
||||
|
||||
@@ -292,6 +292,7 @@ class TFFlaubertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFFlaubertWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFlaubertModelTester(self)
|
||||
|
||||
@@ -339,6 +339,7 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self)
|
||||
@@ -382,6 +383,7 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFFunnelBaseModel, TFFunnelForMultipleChoice, TFFunnelForSequenceClassification) if is_tf_available() else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFFunnelModelTester(self, base=True)
|
||||
|
||||
@@ -333,6 +333,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = True
|
||||
onnx_min_opset = 10
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFGPT2ModelTester(self)
|
||||
|
||||
@@ -195,6 +195,8 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLEDModelTester(self)
|
||||
|
||||
@@ -297,6 +297,8 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLongformerModelTester(self)
|
||||
|
||||
@@ -362,6 +362,7 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFLxmertModel, TFLxmertForPreTraining) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLxmertModelTester(self)
|
||||
|
||||
@@ -179,6 +179,7 @@ class TFMarianModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMarianModelTester(self)
|
||||
|
||||
@@ -181,6 +181,7 @@ class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMBartModelTester(self)
|
||||
|
||||
@@ -56,6 +56,7 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
class TFMobileBertModelTester(object):
|
||||
def __init__(
|
||||
|
||||
@@ -199,6 +199,7 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFMPNetModelTester(self)
|
||||
|
||||
@@ -203,6 +203,7 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFOpenAIGPTLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Add Double HeadsModel when generate() function is changed accordingly
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFOpenAIGPTModelTester(self)
|
||||
|
||||
@@ -177,6 +177,7 @@ class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFPegasusModelTester(self)
|
||||
|
||||
@@ -186,6 +186,7 @@ class TFRobertaModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFRobertaModelTester(self)
|
||||
|
||||
@@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5ModelTester(self)
|
||||
@@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
is_encoder_decoder = False
|
||||
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFT5EncoderOnlyModelTester(self)
|
||||
|
||||
@@ -164,6 +164,7 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# TODO: add this test when TFTransfoXLLMHead has a linear output layer implemented
|
||||
test_resize_embeddings = False
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFTransfoXLModelTester(self)
|
||||
|
||||
@@ -294,6 +294,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLMModelTester(self)
|
||||
|
||||
@@ -348,6 +348,7 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
(TFXLNetLMHeadModel,) if is_tf_available() else ()
|
||||
) # TODO (PVP): Check other models whether language generation is also applicable
|
||||
test_head_masking = False
|
||||
test_onnx = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFXLNetModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user