Onnx fix test (#10663)

* Allow to pass kwargs to model's from_pretrained when using pipeline.

* Disable the use of past_keys_values for GPT2 when exporting to ONNX.

* style

* Remove comment.

* Appease the documentation gods

* Fix style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Funtowicz Morgan
2021-03-11 19:38:29 +01:00
committed by GitHub
parent a637ae00c4
commit 3ab6820370
3 changed files with 26 additions and 15 deletions

View File

@@ -38,19 +38,23 @@ class FuncNonContiguousArgs:
class OnnxExportTestCase(unittest.TestCase):
MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]
MODEL_TO_TEST = [
# (model_name, model_kwargs)
("bert-base-cased", {}),
("gpt2", {"use_cache": False}), # We don't support exporting GPT2 past keys anymore
]
@require_tf
@slow
def test_export_tensorflow(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "tf", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "tf", 12, **model_kwargs)
@require_torch
@slow
def test_export_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
self._test_export(model, "pt", 12, **model_kwargs)
@require_torch
@slow
@@ -71,8 +75,8 @@ class OnnxExportTestCase(unittest.TestCase):
@require_tf
@slow
def test_quantize_tf(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "tf", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "tf", 12, **model_kwargs)
quantized_path = quantize(Path(path))
# Ensure the actual quantized model is not bigger than the original one
@@ -82,15 +86,15 @@ class OnnxExportTestCase(unittest.TestCase):
@require_torch
@slow
def test_quantize_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "pt", 12)
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "pt", 12, **model_kwargs)
quantized_path = quantize(path)
# Ensure the actual quantized model is not bigger than the original one
if quantized_path.stat().st_size >= Path(path).stat().st_size:
self.fail("Quantized model is bigger than initial ONNX model")
def _test_export(self, model, framework, opset, tokenizer=None):
def _test_export(self, model, framework, opset, tokenizer=None, **model_kwargs):
try:
# Compute path
with TemporaryDirectory() as tempdir:
@@ -101,7 +105,7 @@ class OnnxExportTestCase(unittest.TestCase):
path.parent.rmdir()
# Export
convert(framework, model, path, opset, tokenizer)
convert(framework, model, path, opset, tokenizer, **model_kwargs)
return path
except Exception as e: