Added capability to quantize a model while exporting through ONNX. (#6089)
* Added capability to quantize a model while exporting through ONNX. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> We do not support multiple extensions Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Reformat files Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * More quality Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Ensure test_generate_identified_name compares the same object types Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Added documentation everywhere on ONNX exporter Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Use pathlib.Path instead of plain-old string Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Use f-string everywhere Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Use the correct parameters for black formatting Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Use Python 3 super() style. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Use packaging.version to ensure installed onnxruntime version match requirements Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fixing imports sorting order. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Missing raise(s) Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Added quantization documentation Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fix some spelling. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Fix bad list header format Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
import unittest
|
||||
from os.path import dirname, exists
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
|
||||
from transformers.convert_graph_to_onnx import convert, ensure_valid_input, infer_shapes
|
||||
from transformers.convert_graph_to_onnx import (
|
||||
convert,
|
||||
ensure_valid_input,
|
||||
generate_identified_filename,
|
||||
infer_shapes,
|
||||
quantize,
|
||||
)
|
||||
from transformers.testing_utils import require_tf, require_torch, slow
|
||||
|
||||
|
||||
@@ -25,13 +32,13 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
@slow
|
||||
def test_export_tensorflow(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
self._test_export(model, "tf", 11)
|
||||
self._test_export(model, "tf", 12)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_export_pytorch(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
self._test_export(model, "pt", 11)
|
||||
self._test_export(model, "pt", 12)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
@@ -47,7 +54,29 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
with TemporaryDirectory() as bert_save_dir:
|
||||
model = BertModel(BertConfig(vocab_size=len(vocab)))
|
||||
model.save_pretrained(bert_save_dir)
|
||||
self._test_export(bert_save_dir, "pt", 11, tokenizer)
|
||||
self._test_export(bert_save_dir, "pt", 12, tokenizer)
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_quantize_tf(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "tf", 12)
|
||||
quantized_path = quantize(Path(path))
|
||||
|
||||
# Ensure the actual quantized model is not bigger than the original one
|
||||
if quantized_path.stat().st_size >= Path(path).stat().st_size:
|
||||
self.fail("Quantized model is bigger than initial ONNX model")
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_quantize_pytorch(self):
|
||||
for model in OnnxExportTestCase.MODEL_TO_TEST:
|
||||
path = self._test_export(model, "pt", 12)
|
||||
quantized_path = quantize(Path(path))
|
||||
|
||||
# Ensure the actual quantized model is not bigger than the original one
|
||||
if quantized_path.stat().st_size >= Path(path).stat().st_size:
|
||||
self.fail("Quantized model is bigger than initial ONNX model")
|
||||
|
||||
def _test_export(self, model, framework, opset, tokenizer=None):
|
||||
try:
|
||||
@@ -61,6 +90,8 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
|
||||
# Export
|
||||
convert(framework, model, path, opset, tokenizer)
|
||||
|
||||
return path
|
||||
except Exception as e:
|
||||
self.fail(e)
|
||||
|
||||
@@ -138,3 +169,7 @@ class OnnxExportTestCase(unittest.TestCase):
|
||||
# Should have only "input_ids"
|
||||
self.assertEqual(inputs_args[0], tokens["input_ids"])
|
||||
self.assertEqual(ordered_input_names[0], "input_ids")
|
||||
|
||||
def test_generate_identified_name(self):
|
||||
generated = generate_identified_filename(Path("/home/something/my_fake_model.onnx"), "-test")
|
||||
self.assertEqual("/home/something/my_fake_model-test.onnx", generated.as_posix())
|
||||
|
||||
Reference in New Issue
Block a user