Fix flaky ONNX tests (#6531)

This commit is contained in:
Funtowicz Morgan
2020-08-17 15:04:35 +02:00
committed by GitHub
parent 39c3b1d9de
commit b41cc0b86a

View File

@@ -1,7 +1,5 @@
import unittest import unittest
from os.path import dirname, exists
from pathlib import Path from pathlib import Path
from shutil import rmtree
from tempfile import NamedTemporaryFile, TemporaryDirectory from tempfile import NamedTemporaryFile, TemporaryDirectory
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
@@ -72,7 +70,7 @@ class OnnxExportTestCase(unittest.TestCase):
def test_quantize_pytorch(self): def test_quantize_pytorch(self):
for model in OnnxExportTestCase.MODEL_TO_TEST: for model in OnnxExportTestCase.MODEL_TO_TEST:
path = self._test_export(model, "pt", 12) path = self._test_export(model, "pt", 12)
quantized_path = quantize(Path(path)) quantized_path = quantize(path)
# Ensure the actual quantized model is not bigger than the original one # Ensure the actual quantized model is not bigger than the original one
if quantized_path.stat().st_size >= Path(path).stat().st_size: if quantized_path.stat().st_size >= Path(path).stat().st_size:
@@ -82,11 +80,11 @@ class OnnxExportTestCase(unittest.TestCase):
try: try:
# Compute path # Compute path
with TemporaryDirectory() as tempdir: with TemporaryDirectory() as tempdir:
path = tempdir + "/model.onnx" path = Path(tempdir).joinpath("model.onnx")
# Remove folder if exists # Remove folder if exists
if exists(dirname(path)): if path.parent.exists():
rmtree(dirname(path)) path.parent.rmdir()
# Export # Export
convert(framework, model, path, opset, tokenizer) convert(framework, model, path, opset, tokenizer)