Onnx fix test (#10663)
* Allow passing kwargs to the model's from_pretrained when using pipeline.
* Disable the use of past_keys_values for GPT2 when exporting to ONNX.
* Style.
* Remove comment.
* Appease the documentation gods.
* Fix style.

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -38,19 +38,23 @@ class FuncNonContiguousArgs:
|
||||
|
||||
|
||||
class OnnxExportTestCase(unittest.TestCase):
|
||||
MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]
|
||||
MODEL_TO_TEST = [
|
||||
# (model_name, model_kwargs)
|
||||
("bert-base-cased", {}),
|
||||
("gpt2", {"use_cache": False}), # We don't support exporting GPT2 past keys anymore
|
||||
]
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_export_tensorflow(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
self._test_export(model, "tf", 12)
|
||||
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
self._test_export(model, "tf", 12, **model_kwargs)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_export_pytorch(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
self._test_export(model, "pt", 12)
|
||||
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
self._test_export(model, "pt", 12, **model_kwargs)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -71,8 +75,8 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
@require_tf
|
||||
@slow
|
||||
def test_quantize_tf(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "tf", 12)
|
||||
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "tf", 12, **model_kwargs)
|
||||
quantized_path = quantize(Path(path))
|
||||
|
||||
# Ensure the actual quantized model is not bigger than the original one
|
||||
@@ -82,15 +86,15 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
@require_torch
|
||||
@slow
|
||||
def test_quantize_pytorch(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "pt", 12)
|
||||
for model, model_kwargs in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "pt", 12, **model_kwargs)
|
||||
quantized_path = quantize(path)
|
||||
|
||||
# Ensure the actual quantized model is not bigger than the original one
|
||||
if quantized_path.stat().st_size >= Path(path).stat().st_size:
|
||||
self.fail("Quantized model is bigger than initial ONNX model")
|
||||
|
||||
def _test_export(self, model, framework, opset, tokenizer=None):
|
||||
def _test_export(self, model, framework, opset, tokenizer=None, **model_kwargs):
|
||||
try:
|
||||
# Compute path
|
||||
with TemporaryDirectory() as tempdir:
|
||||
@@ -101,7 +105,7 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
path.parent.rmdir()
|
||||
|
||||
# Export
|
||||
convert(framework, model, path, opset, tokenizer)
|
||||
convert(framework, model, path, opset, tokenizer, **model_kwargs)
|
||||
|
||||
return path
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user