Adding optimizations block from ONNXRuntime. (#4431)
* Adding optimizations block from ONNXRuntime.
* Turn off external data format by default for PyTorch export.
* Correct the way use_external_format is passed through the command-line args.
This commit is contained in:
@@ -22,6 +22,7 @@ class OnnxConverterArgumentParser(ArgumentParser):
|
||||
self.add_argument("--framework", type=str, choices=["pt", "tf"], help="Framework for loading the model")
|
||||
self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
|
||||
self.add_argument("--check-loading", action="store_true", help="Check ONNX is able to load the model")
|
||||
self.add_argument("--use-external-format", action="store_true", help="Allow exporting model >= than 2Gb")
|
||||
self.add_argument("output")
|
||||
|
||||
|
||||
@@ -105,7 +106,7 @@ def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] =
|
||||
return pipeline("feature-extraction", model=model, framework=framework)
|
||||
|
||||
|
||||
def convert_pytorch(nlp: Pipeline, opset: int, output: str):
|
||||
def convert_pytorch(nlp: Pipeline, opset: int, output: str, use_external_format: bool):
|
||||
if not is_torch_available():
|
||||
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
|
||||
@@ -126,7 +127,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: str):
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=True,
|
||||
use_external_data_format=use_external_format,
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
@@ -160,7 +161,14 @@ def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
|
||||
)
|
||||
|
||||
|
||||
def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
|
||||
def convert(
|
||||
framework: str,
|
||||
model: str,
|
||||
output: str,
|
||||
opset: int,
|
||||
tokenizer: Optional[str] = None,
|
||||
use_external_format: bool = False,
|
||||
):
|
||||
print("ONNX opset version set to: {}".format(opset))
|
||||
|
||||
# Load the pipeline
|
||||
@@ -175,7 +183,7 @@ def convert(framework: str, model: str, output: str, opset: int, tokenizer: Opti
|
||||
|
||||
# Export the graph
|
||||
if framework == "pt":
|
||||
convert_pytorch(nlp, opset, output)
|
||||
convert_pytorch(nlp, opset, output, use_external_format)
|
||||
else:
|
||||
convert_tensorflow(nlp, opset, output)
|
||||
|
||||
@@ -202,7 +210,7 @@ if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
# Convert
|
||||
convert(args.framework, args.model, args.output, args.opset, args.tokenizer)
|
||||
convert(args.framework, args.model, args.output, args.opset, args.tokenizer, args.use_external_format)
|
||||
|
||||
# And verify
|
||||
if args.check_loading:
|
||||
|
||||
Reference in New Issue
Block a user