Files
HuggingFace_transformer/src/transformers/convert_graph_to_onnx.py
Funtowicz Morgan db0076a9df Conversion script to export transformers models to ONNX IR. (#4253)
* Added generic ONNX conversion script for PyTorch model.

* WIP initial TF support.

* TensorFlow/Keras ONNX export working.

* Print framework version info

* Add possibility to check the model is correctly loading on ONNX runtime.

* Remove quantization option.

* Specify ONNX opset version when exporting.

* Formatting.

* Remove unused imports.

* Make functions more generally reusable from other part of the code.

* isort happy.

* flake happy

* Export only feature-extraction for now

* Correctly check inputs order / filter before export.

* Removed task variable

* Fix invalid args call in load_graph_from_args.

* Fix invalid args call in convert.

* Fix invalid args call in infer_shapes.

* Raise exception and catch in caller function instead of exit.

* Add 04-onnx-export.ipynb notebook

* More WIP on the notebook

* Remove unused imports

* Simplify & remove unused constants.

* Export with constant_folding in PyTorch

* Let's try to put function args in the right order this time ...

* Disable external_data_format temporary

* ONNX notebook draft ready.

* Updated notebooks charts + wording

* Correct error while exporting last chart in notebook.

* Addressing @LysandreJik comment.

* Set ONNX opset to 11 as default value.

* Set opset param mandatory

* Added ONNX export unittests

* Quality.

* flake8 happy

* Add keras2onnx dependency on extras["tf"]

* Pin keras2onnx on github master to v1.6.5

* Second attempt.

* Third attempt.

* Use the right repo URL this time ...

* Do the same for onnxconverter-common

* Updated keras2onnx and onnxconverter-common to 1.7.0 to support TF2.2

* Correct commit hash.

* Addressing PR review: Optimizations are enabled by default.

* Addressing PR review: small changes in the notebook

* setup.py comment about keras2onnx versioning.
2020-05-14 16:35:52 -04:00

213 lines
7.7 KiB
Python

from argparse import ArgumentParser
from itertools import takewhile
from os import listdir, makedirs
from os.path import abspath, dirname, exists
from typing import Dict, List, Optional, Tuple
from transformers import is_tf_available, is_torch_available
from transformers.pipelines import Pipeline, pipeline
from transformers.tokenization_utils import BatchEncoding
class OnnxConverterArgumentParser(ArgumentParser):
    """
    Command-line parser declaring every option understood by the ONNX export script
    """

    def __init__(self):
        super().__init__(prog="ONNX Converter")

        # What to load
        self.add_argument(
            "--model",
            required=True,
            type=str,
            help="Model's id or path (ex: bert-base-cased)",
        )
        self.add_argument(
            "--tokenizer",
            type=str,
            help="Tokenizer's id or path (ex: bert-base-cased)",
        )
        self.add_argument(
            "--framework",
            choices=["pt", "tf"],
            type=str,
            help="Framework for loading the model",
        )

        # How to export it
        self.add_argument(
            "--opset",
            default=11,
            type=int,
            help="ONNX opset to use",
        )
        self.add_argument(
            "--check-loading",
            action="store_true",
            help="Check ONNX is able to load the model",
        )

        # Positional: destination of the exported graph
        self.add_argument("output")
def ensure_valid_input(model, tokens, input_names):
    """
    Arrange the tokenized inputs in the positional order of `model.forward`

    Args:
        model: Object exposing a `forward` method; its argument names drive the ordering
        tokens: Mapping from input name to the value to forward
        input_names: Names of the entries of `tokens` to place

    Returns: Tuple of values in `forward` order, cut at the first position that is
        left unfilled (or holds None)
    """
    forward_varnames = model.forward.__code__.co_varnames

    # Shift by one: the first entry of co_varnames is skipped (presumably `self`)
    positions = [(forward_varnames.index(name) - 1, name) for name in input_names]

    # One slot per positional argument up to the furthest one requested
    last_position = max(position for position, _ in positions)
    slots = [None] * (last_position + 1)
    for position, name in positions:
        slots[position] = tokens[name]

    # Keep only the leading contiguous run of provided values
    ordered = []
    for value in slots:
        if value is None:
            break
        ordered.append(value)
    return tuple(ordered)
def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
    """
    Run a sample sentence through the pipeline to name the graph inputs/outputs
    and to label their dynamic axes

    Args:
        nlp: Pipeline providing the tokenizer and the model
        framework: Value forwarded as `return_tensors` to the tokenizer; "pt" selects
            keyword-argument forwarding, anything else passes the encoding as one argument

    Returns: Tuple (input names, output names, dynamic axes per name, encoded sample)
    """

    def build_shape_dict(tensor, is_input: bool, seq_len: int):
        # Grouped tensors are described element by element
        if isinstance(tensor, (tuple, list)):
            return [build_shape_dict(item, is_input, seq_len) for item in tensor]

        shape = tensor.shape

        # Let's assume batch is the first axis with only 1 element (~~ might not be always true ...)
        unit_axes = [axis for axis, numel in enumerate(shape) if numel == 1]
        axes = {unit_axes[0]: "batch"}

        if is_input:
            # Inputs are expected to be 2-dimensional, with the sequence on axis 1
            if len(shape) != 2:
                raise ValueError("Unable to infer tensor axes ({})".format(len(shape)))
            axes[1] = "sequence"
        else:
            # For outputs, every axis whose size equals the sample length is a sequence axis
            for dim, size in enumerate(shape):
                if size == seq_len:
                    axes[dim] = "sequence"
        return axes

    tokens = nlp.tokenizer.encode_plus("This is a sample output", return_tensors=framework)
    seq_len = tokens.input_ids.shape[-1]

    if framework == "pt":
        outputs = nlp.model(**tokens)
    else:
        outputs = nlp.model(tokens)
    if not isinstance(outputs, (list, tuple)):
        outputs = (outputs,)

    # Generate input names & axes
    input_vars = list(tokens.keys())
    input_dynamic_axes = {name: build_shape_dict(value, True, seq_len) for name, value in tokens.items()}

    # flatten potentially grouped outputs (past for gpt2, attentions)
    outputs_flat = []
    for output in outputs:
        outputs_flat.extend(output if isinstance(output, (tuple, list)) else (output,))

    # Generate output names & axes
    output_names = []
    output_dynamic_axes = {}
    for index, output in enumerate(outputs_flat):
        output_name = "output_{}".format(index)
        output_names.append(output_name)
        output_dynamic_axes[output_name] = build_shape_dict(output, False, seq_len)

    # Create the aggregated axes representation (outputs win on a name clash)
    dynamic_axes = dict(input_dynamic_axes)
    dynamic_axes.update(output_dynamic_axes)
    return input_vars, output_names, dynamic_axes, tokens
def load_graph_from_args(framework: str, model: str, tokenizer: Optional[str] = None) -> Pipeline:
    """
    Build the feature-extraction pipeline that will be exported

    Args:
        framework: Framework identifier forwarded to `pipeline` ("pt" or "tf")
        model: Model's id or path
        tokenizer: Tokenizer's id or path; falls back to `model` when not provided

    Returns: Pipeline
    """
    # If no tokenizer provided
    if tokenizer is None:
        tokenizer = model

    print("Loading pipeline (model: {}, tokenizer: {})".format(model, tokenizer))

    # Allocate tokenizer and model. The tokenizer is passed explicitly so that a
    # user-supplied one is actually used instead of being silently dropped.
    return pipeline("feature-extraction", model=model, tokenizer=tokenizer, framework=framework)
def convert_pytorch(nlp: Pipeline, opset: int, output: str):
    """
    Export the PyTorch model held by a pipeline to an ONNX file

    Args:
        nlp: Pipeline whose `model` is exported
        opset: ONNX opset version forwarded to `torch.onnx.export`
        output: Destination path of the exported graph

    Raises:
        Exception: If PyTorch is not installed
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch
    from torch.onnx import export

    print("PyTorch: {}".format(torch.__version__))

    with torch.no_grad():
        input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
        model_args = ensure_valid_input(nlp.model, tokens, input_names)

        # ensure_valid_input reorders the values to follow forward()'s positional
        # order and may drop trailing ones, whereas input_names is still in tokenizer
        # order. Apply the same ordering and truncation to the names so that each
        # exported graph input is labelled with the argument it really feeds.
        forward_arg_names = nlp.model.forward.__code__.co_varnames
        ordered_input_names = sorted(input_names, key=forward_arg_names.index)[: len(model_args)]

        export(
            nlp.model,
            model_args,
            f=output,
            input_names=ordered_input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            use_external_data_format=True,
            enable_onnx_checker=True,
            opset_version=opset,
        )
def convert_tensorflow(nlp: Pipeline, opset: int, output: str):
    """
    Export the TensorFlow/Keras model held by a pipeline to an ONNX file

    Args:
        nlp: Pipeline whose `model` is exported
        opset: ONNX opset version forwarded to keras2onnx as `target_opset`
        output: Destination path of the exported graph

    Raises:
        Exception: If TensorFlow is not installed, or if a module required for the
            conversion cannot be imported
    """
    if not is_tf_available():
        # Do not rely on the script-level `args` global here: it does not exist when
        # this function is called from other code.
        raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")

    print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")

    try:
        import tensorflow as tf
        from keras2onnx import convert_keras, save_model, __version__ as k2ov

        print("TensorFlow: {}, keras2onnx: {}".format(tf.version.VERSION, k2ov))

        # Build: only the encoded sample is needed afterwards, but infer_shapes also
        # runs the model once on it
        _, _, _, tokens = infer_shapes(nlp, "tf")

        # Forward
        nlp.model.predict(tokens.data)
        onnx_model = convert_keras(nlp.model, nlp.model.name, target_opset=opset)
        save_model(onnx_model, output)

    except ImportError as e:
        # Keep the original ImportError attached as the cause for easier debugging
        raise Exception(
            "Cannot import {} required to convert TF model to ONNX. Please install {} first.".format(e.name, e.name)
        ) from e
def convert(framework: str, model: str, output: str, opset: int, tokenizer: Optional[str] = None):
    """
    Load a model through a pipeline and export it to ONNX

    Args:
        framework: "pt" exports through PyTorch, any other value through TensorFlow
        model: Model's id or path
        output: Path of the ONNX file to write; its parent folder must be missing or empty
        opset: ONNX opset version to use
        tokenizer: Tokenizer's id or path (optional)

    Raises:
        Exception: If the parent folder of `output` already exists and is not empty
    """
    print("ONNX opset version set to: {}".format(opset))

    # Load the pipeline
    nlp = load_graph_from_args(framework, model, tokenizer)

    # Resolve the parent from the absolute path: dirname() of a bare file name is "",
    # which can neither be checked nor created.
    parent = dirname(abspath(output))
    if not exists(parent):
        print("Creating folder {}".format(parent))
        makedirs(parent)
    elif listdir(parent):
        raise Exception("Folder {} is not empty, aborting conversion".format(parent))

    # Export the graph
    if framework == "pt":
        convert_pytorch(nlp, opset, output)
    else:
        convert_tensorflow(nlp, opset, output)
def verify(path: str):
    """
    Try to open the exported graph with onnxruntime on the CPU provider and print the outcome

    Args:
        path: Location of the ONNX file to load
    """
    from onnxruntime import InferenceSession, SessionOptions
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException

    print("Checking ONNX model loading from: {}".format(path))
    try:
        # Creating the session is the whole check; the session itself is discarded
        InferenceSession(path, SessionOptions(), providers=["CPUExecutionProvider"])
    except RuntimeException as load_error:
        print("Error while loading the model: {}".format(load_error))
    else:
        print("Model correctly loaded")
if __name__ == "__main__":
    # Script entry point: parse the command line, export, optionally check the result
    args = OnnxConverterArgumentParser().parse_args()

    # The export works on an absolute destination path
    args.output = abspath(args.output)

    try:
        # Export the model ...
        convert(args.framework, args.model, args.output, args.opset, args.tokenizer)

        # ... then, on request, make sure the exported file can be loaded back
        if args.check_loading:
            verify(args.output)

    except Exception as error:
        # Any failure is reported on stdout and turned into a non-zero exit status
        print("Error while converting the model: {}".format(error))
        exit(1)