Fix flaky ONNX tests (#6531)
This commit is contained in:
@@ -1,7 +1,5 @@
|
|||||||
import unittest
|
import unittest
|
||||||
from os.path import dirname, exists
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree
|
|
||||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||||
|
|
||||||
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
|
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
|
||||||
@@ -72,7 +70,7 @@ class OnnxExportTestCase(unittest.TestCase):
|
|||||||
def test_quantize_pytorch(self):
|
def test_quantize_pytorch(self):
|
||||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||||
path = self._test_export(model, "pt", 12)
|
path = self._test_export(model, "pt", 12)
|
||||||
quantized_path = quantize(Path(path))
|
quantized_path = quantize(path)
|
||||||
|
|
||||||
# Ensure the actual quantized model is not bigger than the original one
|
# Ensure the actual quantized model is not bigger than the original one
|
||||||
if quantized_path.stat().st_size >= Path(path).stat().st_size:
|
if quantized_path.stat().st_size >= Path(path).stat().st_size:
|
||||||
@@ -82,16 +80,16 @@ class OnnxExportTestCase(unittest.TestCase):
|
|||||||
try:
|
try:
|
||||||
# Compute path
|
# Compute path
|
||||||
with TemporaryDirectory() as tempdir:
|
with TemporaryDirectory() as tempdir:
|
||||||
path = tempdir + "/model.onnx"
|
path = Path(tempdir).joinpath("model.onnx")
|
||||||
|
|
||||||
# Remove folder if exists
|
# Remove folder if exists
|
||||||
if exists(dirname(path)):
|
if path.parent.exists():
|
||||||
rmtree(dirname(path))
|
path.parent.rmdir()
|
||||||
|
|
||||||
# Export
|
# Export
|
||||||
convert(framework, model, path, opset, tokenizer)
|
convert(framework, model, path, opset, tokenizer)
|
||||||
|
|
||||||
return path
|
return path
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.fail(e)
|
self.fail(e)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user