From 5008e08885b21a4cdc1df8efa4405a335db06128 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 9 Aug 2021 15:51:49 +0200 Subject: [PATCH] Add to ONNX docs (#13048) * Add to ONNX docs * Add MBART example * Update docs/source/serialization.rst Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/serialization.rst | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 728e4b7238..fada8bd98e 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -99,6 +99,30 @@ It will be exported under ``onnx/bert-base-cased``. You should see similar logs: -[✓] all values close (atol: 0.0001) All good, model saved at: onnx/bert-base-cased/model.onnx +This export can now be used in the ONNX inference runtime: + +.. code-block:: + + import onnxruntime as ort + + from transformers import BertTokenizerFast + tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") + + ort_session = ort.InferenceSession("onnx/bert-base-cased/model.onnx") + + inputs = tokenizer("Using BERT in ONNX!", return_tensors="np") + outputs = ort_session.run(["last_hidden_state", "pooler_output"], dict(inputs)) + +The outputs used (:obj:`["last_hidden_state", "pooler_output"]`) can be obtained by taking a look at the ONNX +configuration of each model. For example, for BERT: + +.. code-block:: + + from transformers.models.bert import BertOnnxConfig, BertConfig + + config = BertConfig() + onnx_config = BertOnnxConfig(config) + output_keys = list(onnx_config.outputs.keys()) Implementing a custom configuration for an unsupported architecture ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -142,6 +166,12 @@ An important fact to notice is the use of `OrderedDict` in both inputs and outpu as inputs are matched against their relative position within the `PreTrainedModel.forward()` prototype and outputs are match against there position in the returned `BaseModelOutputX` instance. +An example of such an addition is visible here, for the MBart model: `Making MBART ONNX-convertible +`__ + +If you would like to contribute your addition to the library, we recommend you implement tests. An example of such +tests is visible here: `Adding tests to the MBART ONNX conversion +`__ Graph conversion -----------------------------------------------------------------------------------------------------------------------