Check TF ops for ONNX compliance (#10025)

* Add check-ops script

* Finish to implement check_tf_ops and start the test

* Make the test mandatory only for BERT

* Update tf_ops folder

* Remove useless classes

* Add the ONNX test for GPT2 and BART

* Add a onnxruntime slow test + better opset flexibility

* Fix test + apply style

* fix tests

* Switch min opset from 12 to 10

* Update src/transformers/file_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Fix GPT2

* Remove extra shape_list usage

* Fix GPT2

* Address Morgan's comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Julien Plu
2021-02-15 13:55:10 +01:00
committed by GitHub
parent 93bd2f7099
commit c8d3fa0dfd
33 changed files with 468 additions and 17 deletions

View File

@@ -249,6 +249,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
all_model_classes = (TFT5Model, TFT5ForConditionalGeneration) if is_tf_available() else ()
all_generative_model_classes = (TFT5ForConditionalGeneration,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFT5ModelTester(self)
@@ -427,6 +428,7 @@ class TFT5EncoderOnlyModelTest(TFModelTesterMixin, unittest.TestCase):
is_encoder_decoder = False
all_model_classes = (TFT5EncoderModel,) if is_tf_available() else ()
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFT5EncoderOnlyModelTester(self)