From 2aa3cd935d0f3bcd04ce66be6af4b760493d2ffe Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Thu, 8 Jul 2021 16:54:42 +0200 Subject: [PATCH] [RFC] Laying down building stone for more flexible ONNX export capabilities (#11786) * Laying down building stone for more flexible ONNX export capabilities * Ability to provide a map of config key to override before exporting. * Makes it possible to export BART with/without past keys. * Supports simple mathematical syntax for OnnxVariable.repeated * Effectively apply value override from onnx config for model * Supports export with additional features such as with-past for seq2seq * Store the output path directly in the args for uniform usage across. * Make BART_ONNX_CONFIG_* constants and fix imports. * Support BERT model. * Use tokenizer for more flexibility in defining the inputs of a model. * Add TODO as remainder to provide the batch/sequence_length as CLI args * Enable optimizations to be done on the model. * Enable GPT2 + past * Improve model validation with outputs containing nested structures * Enable Roberta * Enable Albert * Albert requires opset >= 12 * BERT-like models requires opset >= 12 * Remove double printing. * Enable XLM-Roberta * Enable DistilBERT * Disable optimization by default * Fix missing setattr when applying optimizer_features * Add value field to OnnxVariable to define constant input (not from tokenizers) * Add T5 support. * Simplify model type retrieval * Example exporting token_classification pipeline for DistilBERT. * Refactoring to package `transformers.onnx` * Solve circular dependency & __main__ * Remove unnecessary imports in `__init__` * Licences * Use @Narsil's suggestion to forward the model's configuration to the ONNXConfig to avoid interpolation. * Onnx export v2 fixes (#12388) * Tiny fixes Remove `convert_pytorch` from onnxruntime-less runtimes Correct reference to model * Style * Fix Copied from * LongFormer ONNX config. * Removed optimizations * Remvoe bad merge relicas. * Remove unused constants. * Remove some deleted constants from imports. * Fix unittest to remove usage of PyTorch model for onnx.utils. * Fix distilbert export * Enable ONNX export test for supported model. * Style. * Fix lint. * Enable all supported default models. * GPT2 only has one output * Fix bad property name when overriding config. * Added unittests and docstrings. * Disable with_past tests for now. * Enable outputs validation for default export. * Remove graph opt lvls. * Last commit with on-going past commented. * Style. * Disabled `with_past` for now * Remove unused imports. * Remove framework argument * Remove TFPreTrainedModel reference * Add documentation * Add onnxruntime tests to CircleCI * Add test * Rename `convert_pytorch` to `export` * Use OrderedDict for dummy inputs * WIP Wav2Vec2 * Revert "WIP Wav2Vec2" This reverts commit f665efb04c92525c3530e589029f0ae7afdf603e. * Style * Use OrderedDict for I/O * Style. * Specify OrderedDict documentation. * Style :) Co-authored-by: Lysandre Co-authored-by: Lysandre Debut --- .circleci/config.yml | 27 ++ docs/source/serialization.rst | 126 +++++++++ src/transformers/file_utils.py | 35 ++- src/transformers/models/albert/__init__.py | 4 +- .../models/albert/configuration_albert.py | 20 ++ src/transformers/models/bart/__init__.py | 4 +- .../models/bart/configuration_bart.py | 32 +++ src/transformers/models/bert/__init__.py | 4 +- .../models/bert/configuration_bert.py | 19 ++ .../models/distilbert/__init__.py | 12 +- .../distilbert/configuration_distilbert.py | 18 ++ src/transformers/models/gpt2/__init__.py | 4 +- .../models/gpt2/configuration_gpt2.py | 63 +++++ .../models/longformer/__init__.py | 12 +- .../longformer/configuration_longformer.py | 20 +- src/transformers/models/roberta/__init__.py | 4 +- .../models/roberta/configuration_roberta.py | 18 ++ src/transformers/models/t5/__init__.py | 4 +- .../models/t5/configuration_t5.py | 68 +++++ .../models/xlm_roberta/__init__.py | 12 +- .../xlm_roberta/configuration_xlm_roberta.py | 19 ++ src/transformers/onnx/__init__.py | 18 ++ src/transformers/onnx/__main__.py | 150 +++++++++++ src/transformers/onnx/config.py | 223 ++++++++++++++++ src/transformers/onnx/convert.py | 225 ++++++++++++++++ src/transformers/onnx/utils.py | 82 ++++++ src/transformers/testing_utils.py | 8 + tests/test_modeling_tf_common.py | 4 +- tests/test_onnx_v2.py | 251 ++++++++++++++++++ 29 files changed, 1461 insertions(+), 25 deletions(-) create mode 100644 src/transformers/onnx/__init__.py create mode 100644 src/transformers/onnx/__main__.py create mode 100644 src/transformers/onnx/config.py create mode 100644 src/transformers/onnx/convert.py create mode 100644 src/transformers/onnx/utils.py create mode 100644 tests/test_onnx_v2.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 15526cb275..f76343ac66 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -345,6 +345,32 @@ jobs: - '~/.cache/pip' - run: python -m pytest -sv ./tests/ -m is_staging_test + run_tests_onnxruntime: + working_directory: ~/transformers + docker: + - image: circleci/python:3.7 + environment: + OMP_NUM_THREADS: 1 + TRANSFORMERS_IS_CI: yes + resource_class: xlarge + parallelism: 1 + steps: + - checkout + - restore_cache: + keys: + - v0.4-torch-{{ checksum "setup.py" }} + - v0.4-{{ checksum "setup.py" }} + - run: pip install --upgrade pip + - run: pip install .[torch,testing,sentencepiece,onnxruntime] + - save_cache: + key: v0.4-onnx-{{ checksum "setup.py" }} + paths: + - '~/.cache/pip' + - run: python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch ./tests/* -k onnx | tee tests_output.txt + - store_artifacts: + path: ~/transformers/tests_output.txt + - store_artifacts: + path: ~/transformers/reports build_doc: working_directory: ~/transformers docker: @@ -485,6 +511,7 @@ workflows: - run_tests_flax - run_tests_pipelines_torch - run_tests_pipelines_tf + - run_tests_onnxruntime - run_tests_hub - build_doc - deploy_doc: *workflow_filters diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst index 35fa199b1d..d64c60490a 100644 --- a/docs/source/serialization.rst +++ b/docs/source/serialization.rst @@ -21,11 +21,137 @@ Projects `ONNX (Open Neural Network eXchange) `_ and `ONNXRuntim unified and community-driven format to store and, by extension, efficiently execute neural network leveraging a variety of hardware and dedicated optimizations. + Starting from transformers v2.10.0 we partnered with ONNX Runtime to provide an easy export of transformers models to the ONNX format. You can have a look at the effort by looking at our joint blog post `Accelerate your NLP pipelines using Hugging Face Transformers and ONNX Runtime `_. + +Configuration-based approach +----------------------------------------------------------------------------------------------------------------------- + +Transformers v4.9.0 introduces a new package: ``transformers.onnx``. This package allows converting checkpoints to an +ONNX graph by leveraging configuration objects. These configuration objects come ready made for a number of model +architectures, and are made to be easily extendable to other architectures. + +Ready-made configurations include the following models: + +- ALBERT +- BART +- BERT +- DistilBERT +- GPT-2 +- RoBERTa +- T5 +- XLM-RoBERTa + +This conversion is handled with the PyTorch version of models - it, therefore, requires PyTorch to be installed. If you +would like to be able to convert from TensorFlow, please let us know by opening an issue. + +.. note:: + The models showcased here are close to fully feature complete, but do lack some features that are currently in + development. Namely, the ability to handle the past key values for decoder models is currently in the works. + + +Converting an ONNX model using the ``transformers.onnx`` package +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The package may be used as a Python module: + +.. code-block:: + + python -m transformers.onnx --help + + usage: Hugging Face ONNX Exporter tool [-h] -m MODEL -f {pytorch} [--features {default}] [--opset OPSET] [--atol ATOL] output + + positional arguments: + output Path indicating where to store generated ONNX model. + + optional arguments: + -h, --help show this help message and exit + -m MODEL, --model MODEL + Model's name of path on disk to load. + -f {pytorch}, --framework {pytorch} + Framework to use when exporting. Possible values are: {'pytorch'} + --features {default} Export the model with some additional features. + --opset OPSET ONNX opset version to export the model with (default 12). + --atol ATOL Absolute difference tolerance when validating the model. + +Exporting a checkpoint using a ready-made configuration can be done as follows: + +.. code-block:: + + python -m transformers.onnx -f pytorch --model=bert-base-cased onnx/bert-base-cased/ + +This exports an ONNX graph of the mentioned checkpoint. Here it is `bert-base-cased`, but it can be any model from the +hub, or a local path. + +It will be exported under ``onnx/bert-base-cased``. You should see similar logs: + +.. code-block:: + + Validating ONNX model... + -[✓] ONNX model outputs' name match reference model ({'pooler_output', 'last_hidden_state'} + - Validating ONNX Model output "last_hidden_state": + -[✓] (2, 8, 768) matchs (2, 8, 768) + -[✓] all values close (atol: 0.0001) + - Validating ONNX Model output "pooler_output": + -[✓] (2, 768) matchs (2, 768) + -[✓] all values close (atol: 0.0001) + All good, model saved at: onnx/bert-base-cased/model.onnx + + +Implementing a custom configuration for an unsupported architecture +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Let's take a look at the changes necessary to add a custom configuration for an unsupported architecture. Firstly, we +will need a custom ONNX configuration object that details the model inputs and outputs. The BERT ONNX configuration is +visible below: + +.. code-block:: + + class BertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ("token_type_ids", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) + +Let's understand what's happening here. This configuration has two properties: the inputs, and the outputs. + +The inputs return a dictionary, where each key corresponds to an expected input, and each value indicates the axis of +that input. + +For BERT, there are three necessary inputs. These three inputs are of similar shape, which is made up of two +dimensions: the batch is the first dimension, and the second is the sequence. + +The outputs return a similar dictionary, where, once again, each key corresponds to an expected output, and each value +indicates the axis of that output. + +Once this is done, a single step remains: adding this configuration object to the initialisation of the model class, +and to the general ``transformers`` initialisation. + +An important fact to notice is the use of `OrderedDict` in both inputs and outputs properties. This is a requirements +as inputs are matched against their relative position within the `PreTrainedModel.forward()` prototype and outputs are +match against there position in the returned `BaseModelOutputX` instance. + + +Graph conversion +----------------------------------------------------------------------------------------------------------------------- + +.. note:: + The approach detailed here is bing deprecated. We recommend you follow the part above for an up to date approach. + + Exporting a model is done through the script `convert_graph_to_onnx.py` at the root of the transformers sources. The following command shows how easy it is to export a BERT model from the library, simply run: diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 021d2fe66e..8e7ee5a3f8 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -148,9 +148,30 @@ except importlib_metadata.PackageNotFoundError: _faiss_available = False -_onnx_available = ( - importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None -) +coloredlogs = importlib.util.find_spec("coloredlogs") is not None +try: + _coloredlogs_available = importlib_metadata.version("coloredlogs") + logger.debug(f"Successfully imported sympy version {_coloredlogs_available}") +except importlib_metadata.PackageNotFoundError: + _coloredlogs_available = False + + +sympy_available = importlib.util.find_spec("sympy") is not None +try: + _sympy_available = importlib_metadata.version("sympy") + logger.debug(f"Successfully imported sympy version {_sympy_available}") +except importlib_metadata.PackageNotFoundError: + _sympy_available = False + + +_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None +try: + _keras2onnx_version = importlib_metadata.version("keras2onnx") + logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}") +except importlib_metadata.PackageNotFoundError: + _keras2onnx_available = False + +_onnx_available = importlib.util.find_spec("onnxruntime") is not None try: _onxx_version = importlib_metadata.version("onnx") logger.debug(f"Successfully imported onnx version {_onxx_version}") @@ -292,6 +313,14 @@ def is_tf_available(): return _tf_available +def is_coloredlogs_available(): + return _coloredlogs_available + + +def is_keras2onnx_available(): + return _keras2onnx_available + + def is_onnx_available(): return _onnx_available diff --git a/src/transformers/models/albert/__init__.py b/src/transformers/models/albert/__init__.py index a358918ae0..44e43eb30b 100644 --- a/src/transformers/models/albert/__init__.py +++ b/src/transformers/models/albert/__init__.py @@ -28,7 +28,7 @@ from ...file_utils import ( _import_structure = { - "configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], + "configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig", "AlbertOnnxConfig"], } if is_sentencepiece_available(): @@ -67,7 +67,7 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig + from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig if is_sentencepiece_available(): from .tokenization_albert import AlbertTokenizer diff --git a/src/transformers/models/albert/configuration_albert.py b/src/transformers/models/albert/configuration_albert.py index f69b87ba6d..2bf3171d0d 100644 --- a/src/transformers/models/albert/configuration_albert.py +++ b/src/transformers/models/albert/configuration_albert.py @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ ALBERT model configuration """ +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { @@ -151,3 +154,20 @@ class AlbertConfig(PretrainedConfig): self.layer_norm_eps = layer_norm_eps self.classifier_dropout_prob = classifier_dropout_prob self.position_embedding_type = position_embedding_type + + +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Roberta->Albert +class AlbertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ("token_type_ids", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py index 89d3584440..a8ddcecc41 100644 --- a/src/transformers/models/bart/__init__.py +++ b/src/transformers/models/bart/__init__.py @@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to _import_structure = { - "configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig"], + "configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig", "BartOnnxConfig"], "tokenization_bart": ["BartTokenizer"], } @@ -53,7 +53,7 @@ if is_flax_available(): ] if TYPE_CHECKING: - from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig + from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, BartOnnxConfig from .tokenization_bart import BartTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 259beda019..3890a9c803 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -14,8 +14,11 @@ # limitations under the License. """ BART model configuration """ import warnings +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -186,3 +189,32 @@ class BartConfig(PretrainedConfig): @property def hidden_size(self) -> int: return self.d_model + + +class BartOnnxConfig(OnnxConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.use_past: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + else: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/models/bert/__init__.py b/src/transformers/models/bert/__init__.py index 81978cf1b7..9bcf372282 100644 --- a/src/transformers/models/bert/__init__.py +++ b/src/transformers/models/bert/__init__.py @@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to _import_structure = { - "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig"], + "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"], "tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"], } @@ -77,7 +77,7 @@ if is_flax_available(): ] if TYPE_CHECKING: - from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig + from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index 5555704858..92e989c803 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -14,8 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ BERT model configuration """ +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -154,3 +157,19 @@ class BertConfig(PretrainedConfig): self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache + + +class BertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ("token_type_ids", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/distilbert/__init__.py b/src/transformers/models/distilbert/__init__.py index 4c00e7b2fe..752a31e0be 100644 --- a/src/transformers/models/distilbert/__init__.py +++ b/src/transformers/models/distilbert/__init__.py @@ -22,7 +22,11 @@ from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, _import_structure = { - "configuration_distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig"], + "configuration_distilbert": [ + "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "DistilBertConfig", + "DistilBertOnnxConfig", + ], "tokenization_distilbert": ["DistilBertTokenizer"], } @@ -56,7 +60,11 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig + from .configuration_distilbert import ( + DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + DistilBertConfig, + DistilBertOnnxConfig, + ) from .tokenization_distilbert import DistilBertTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/distilbert/configuration_distilbert.py b/src/transformers/models/distilbert/configuration_distilbert.py index df561b6516..a171ea1dca 100644 --- a/src/transformers/models/distilbert/configuration_distilbert.py +++ b/src/transformers/models/distilbert/configuration_distilbert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ DistilBERT model configuration """ +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -135,3 +138,18 @@ class DistilBertConfig(PretrainedConfig): @property def num_hidden_layers(self): return self.n_layers + + +class DistilBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})]) diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py index a68eb6062f..1df4f9daf6 100644 --- a/src/transformers/models/gpt2/__init__.py +++ b/src/transformers/models/gpt2/__init__.py @@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to _import_structure = { - "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config"], + "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"], "tokenization_gpt2": ["GPT2Tokenizer"], } @@ -55,7 +55,7 @@ if is_flax_available(): _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"] if TYPE_CHECKING: - from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config + from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig from .tokenization_gpt2 import GPT2Tokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index 00d7b88a4f..503199e4cf 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -14,8 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ OpenAI GPT-2 configuration """ +from collections import OrderedDict +from typing import Any, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -195,3 +200,61 @@ class GPT2Config(PretrainedConfig): @property def num_hidden_layers(self): return self.n_layer + + +class GPT2OnnxConfig(OnnxConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch"}}) + if self.use_past: + for i in range(self._config.n_layer * 2): + common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"} + + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}) + if self.use_past: + for i in range(self._config.n_layer * 2): + common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"} + + return common_outputs + + return common_outputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch = common_inputs["input_ids"].shape[0] + ordered_inputs["past_key_values"] = [ + ( + torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)), + torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)), + ) + for _ in range(self._config.n_layer) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + return ordered_inputs diff --git a/src/transformers/models/longformer/__init__.py b/src/transformers/models/longformer/__init__.py index c2430668a7..959887f479 100644 --- a/src/transformers/models/longformer/__init__.py +++ b/src/transformers/models/longformer/__init__.py @@ -22,7 +22,11 @@ from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, _import_structure = { - "configuration_longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig"], + "configuration_longformer": [ + "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", + "LongformerConfig", + "LongformerOnnxConfig", + ], "tokenization_longformer": ["LongformerTokenizer"], } @@ -57,7 +61,11 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig + from .configuration_longformer import ( + LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, + LongformerConfig, + LongformerOnnxConfig, + ) from .tokenization_longformer import LongformerTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/longformer/configuration_longformer.py b/src/transformers/models/longformer/configuration_longformer.py index 3efd5781d2..3c72fc2763 100644 --- a/src/transformers/models/longformer/configuration_longformer.py +++ b/src/transformers/models/longformer/configuration_longformer.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Longformer configuration """ +from collections import OrderedDict +from typing import List, Mapping, Union -from typing import List, Union - +from ...onnx import OnnxConfig from ...utils import logging from ..roberta.configuration_roberta import RobertaConfig @@ -69,3 +70,18 @@ class LongformerConfig(RobertaConfig): def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs): super().__init__(sep_token_id=sep_token_id, **kwargs) self.attention_window = attention_window + + +class LongformerOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py index e9efcfc67d..e76597f3b9 100644 --- a/src/transformers/models/roberta/__init__.py +++ b/src/transformers/models/roberta/__init__.py @@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to _import_structure = { - "configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig"], + "configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaOnnxConfig"], "tokenization_roberta": ["RobertaTokenizer"], } @@ -68,7 +68,7 @@ if is_flax_available(): if TYPE_CHECKING: - from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig + from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaOnnxConfig from .tokenization_roberta import RobertaTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/roberta/configuration_roberta.py b/src/transformers/models/roberta/configuration_roberta.py index 14598a305f..25fc855bd4 100644 --- a/src/transformers/models/roberta/configuration_roberta.py +++ b/src/transformers/models/roberta/configuration_roberta.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ RoBERTa configuration """ +from collections import OrderedDict +from typing import Mapping +from ...onnx import OnnxConfig from ...utils import logging from ..bert.configuration_bert import BertConfig @@ -62,3 +65,18 @@ class RobertaConfig(BertConfig): def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs): """Constructs RobertaConfig.""" super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class RobertaOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/models/t5/__init__.py b/src/transformers/models/t5/__init__.py index cd03159570..0b6e8f8ac4 100644 --- a/src/transformers/models/t5/__init__.py +++ b/src/transformers/models/t5/__init__.py @@ -29,7 +29,7 @@ from ...file_utils import ( _import_structure = { - "configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], + "configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"], } if is_sentencepiece_available(): @@ -66,7 +66,7 @@ if is_flax_available(): if TYPE_CHECKING: - from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config + from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config, T5OnnxConfig if is_sentencepiece_available(): from .tokenization_t5 import T5Tokenizer diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 1e52a0a317..5a6feb5d8e 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. """ T5 model configuration """ +from collections import OrderedDict +from typing import Any, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfigWithPast from ...utils import logging @@ -132,3 +137,66 @@ class T5Config(PretrainedConfig): @property def num_hidden_layers(self): return self.num_layers + + +class T5OnnxConfig(OnnxConfigWithPast): + def __init__(self, config: PretrainedConfig, use_past: bool = False): + super().__init__(config, use_past) + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch"}), + ("decoder_attention_mask", {0: "batch"}), + ] + ) + + if self.use_past: + for i in range(self._config.num_layers): + common_inputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "past_sequence"},) + common_inputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "past_sequence"},) + common_inputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "past_sequence"},) + common_inputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "past_sequence"},) + + return common_inputs + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + common_outputs = OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "decoder_sequence"}), + ("encoder_last_hidden_state", {0: "batch", 2: "encoder_sequence"}), + ] + ) + + if self.use_past: + for i in range(self._config.num_layers): + common_outputs[f"past_key_values.{i}.decoder.0"] = ({0: "batch", 2: "decoder_sequence"},) + common_outputs[f"past_key_values.{i}.decoder.1"] = ({0: "batch", 2: "decoder_sequence"},) + common_outputs[f"past_key_values.{i}.encoder.0"] = ({0: "batch", 2: "encoder_sequence"},) + common_outputs[f"past_key_values.{i}.encoder.1"] = ({0: "batch", 2: "encoder_sequence"},) + + return common_outputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.use_past: + raise NotImplementedError() + + # Generate encoder inputs + encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework) + + # Generate decoder inputs + decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + + return dict(**encoder_inputs, **decoder_inputs) diff --git a/src/transformers/models/xlm_roberta/__init__.py b/src/transformers/models/xlm_roberta/__init__.py index 9cc49ab315..7ef5dd2c03 100644 --- a/src/transformers/models/xlm_roberta/__init__.py +++ b/src/transformers/models/xlm_roberta/__init__.py @@ -28,7 +28,11 @@ from ...file_utils import ( _import_structure = { - "configuration_xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], + "configuration_xlm_roberta": [ + "XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", + "XLMRobertaConfig", + "XLMRobertaOnnxConfig", + ], } if is_sentencepiece_available(): @@ -62,7 +66,11 @@ if is_tf_available(): if TYPE_CHECKING: - from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig + from .configuration_xlm_roberta import ( + XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, + XLMRobertaConfig, + XLMRobertaOnnxConfig, + ) if is_sentencepiece_available(): from .tokenization_xlm_roberta import XLMRobertaTokenizer diff --git a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py index 2ca58306c0..9300bfcc79 100644 --- a/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/configuration_xlm_roberta.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ XLM-RoBERTa configuration """ +from collections import OrderedDict +from typing import Mapping +from ...onnx import OnnxConfig from ...utils import logging from ..roberta.configuration_roberta import RobertaConfig @@ -38,3 +41,19 @@ class XLMRobertaConfig(RobertaConfig): """ model_type = "xlm-roberta" + + +# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->XLMRoberta +class XLMRobertaOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})]) diff --git a/src/transformers/onnx/__init__.py b/src/transformers/onnx/__init__.py new file mode 100644 index 0000000000..a61e475b82 --- /dev/null +++ b/src/transformers/onnx/__init__.py @@ -0,0 +1,18 @@ +# flake8: noqa +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast +from .convert import export, validate_model_outputs +from .utils import ParameterFormat, compute_serialized_parameters_size diff --git a/src/transformers/onnx/__main__.py b/src/transformers/onnx/__main__.py new file mode 100644 index 0000000000..2c7b2a6952 --- /dev/null +++ b/src/transformers/onnx/__main__.py @@ -0,0 +1,150 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from pathlib import Path +from typing import Callable, Tuple + +from transformers.models.albert import AlbertOnnxConfig +from transformers.models.auto import AutoTokenizer +from transformers.models.bart import BartOnnxConfig +from transformers.models.bert import BertOnnxConfig +from transformers.models.distilbert import DistilBertOnnxConfig +from transformers.models.gpt2 import GPT2OnnxConfig +from transformers.models.longformer import LongformerOnnxConfig +from transformers.models.roberta import RobertaOnnxConfig +from transformers.models.t5 import T5OnnxConfig +from transformers.models.xlm_roberta import XLMRobertaOnnxConfig + +from .. import is_torch_available +from ..utils import logging +from .convert import export, validate_model_outputs + + +if is_torch_available(): + from transformers import AutoModel, PreTrainedModel + + FEATURES_TO_AUTOMODELS = { + "default": AutoModel, + } + + +# Set of model topologies we support associated to the features supported by each topology and the factory +SUPPORTED_MODEL_KIND = { + "albert": {"default": AlbertOnnxConfig.default}, + "bart": {"default": BartOnnxConfig.default}, + "bert": {"default": BertOnnxConfig.default}, + "distilbert": {"default": DistilBertOnnxConfig.default}, + "gpt2": {"default": GPT2OnnxConfig.default}, + "longformer": {"default": LongformerOnnxConfig.default}, + "roberta": {"default": RobertaOnnxConfig}, + "t5": {"default": T5OnnxConfig.default}, + "xlm-roberta": {"default": XLMRobertaOnnxConfig.default}, +} + + +def get_model_from_features(features: str, model: str): + """ + Attempt to retrieve a model from a model's name and the features to be enabled. + + Args: + features: The features required + model: The name of the model to export + + Returns: + + """ + if features not in FEATURES_TO_AUTOMODELS: + raise KeyError(f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}") + + return FEATURES_TO_AUTOMODELS[features].from_pretrained(model) + + +def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]: + """ + Check whether or not the model has the requested features + + Args: + model: The model to export + features: The name of the features to check if they are avaiable + + Returns: + (str) The type of the model (OnnxConfig) The OnnxConfig instance holding the model export properties + + """ + if model.config.model_type not in SUPPORTED_MODEL_KIND: + raise KeyError( + f"{model.config.model_type} ({model.name}) is not supported yet. " + f"Only {SUPPORTED_MODEL_KIND} are supported. " + f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue." + ) + + # Look for the features + model_features = SUPPORTED_MODEL_KIND[model.config.model_type] + if features not in model_features: + raise ValueError( + f"{model.config.model_type} doesn't support features {features}. " + f"Supported values are: {list(model_features.keys())}" + ) + + return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features] + + +def main(): + parser = ArgumentParser("Hugging Face ONNX Exporter tool") + parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.") + parser.add_argument( + "--features", + choices=["default"], + default="default", + help="Export the model with some additional features.", + ) + parser.add_argument( + "--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)." + ) + parser.add_argument( + "--atol", type=float, default=1e-4, help="Absolute difference tolerence when validating the model." + ) + parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.") + + # Retrieve CLI arguments + args = parser.parse_args() + args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx") + + if not args.output.parent.exists(): + args.output.parent.mkdir(parents=True) + + # Allocate the model + tokenizer = AutoTokenizer.from_pretrained(args.model) + model = get_model_from_features(args.features, args.model) + model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features) + onnx_config = model_onnx_config(model.config) + + # Ensure the requested opset is sufficient + if args.opset < onnx_config.default_onnx_opset: + raise ValueError( + f"Opset {args.opset} is not sufficient to export {model_kind}. " + f"At least {onnx_config.default_onnx_opset} is required." + ) + + onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output) + + validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol) + logger.info(f"All good, model saved at: {args.output.as_posix()}") + + +if __name__ == "__main__": + logger = logging.get_logger("transformers.onnx") # pylint: disable=invalid-name + logger.setLevel(logging.INFO) + main() diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py new file mode 100644 index 0000000000..2cf11368a0 --- /dev/null +++ b/src/transformers/onnx/config.py @@ -0,0 +1,223 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Any, Mapping, Optional + +from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType + +from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size + + +DEFAULT_ONNX_OPSET = 11 + +# 2 Gb +EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024 + + +class OnnxConfig(ABC): + """ + Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format. + """ + + DEFAULT_FIXED_BATCH = 2 + DEFAULT_FIXED_SEQUENCE = 8 + + def __init__(self, config: PretrainedConfig): + self._config = config + + @classmethod + def default(cls, config: PretrainedConfig) -> "OnnxConfig": + """ + Instantiate a OnnxConfig for a specific model + + Args: + config: The model's configuration to use when exporting to ONNX + + Returns: + OnnxConfig for this model + """ + return cls(config) + + @property + @abstractmethod + def inputs(self) -> Mapping[str, Mapping[int, str]]: + """ + Mapping containing the axis definition of the input tensors to provide to the model + + Returns: + For each input: its name associated to the axes symbolic name and the axis position within the tensor + """ + raise NotImplementedError() + + @property + @abstractmethod + def outputs(self) -> Mapping[str, Mapping[int, str]]: + """ + Mapping containing the axis definition of the output tensors to provide to the model + + Returns: + For each output: its name associated to the axes symbolic name and the axis position within the tensor + """ + raise NotImplementedError() + + @property + def values_override(self) -> Optional[Mapping[str, Any]]: + """ + Dictionary of keys to override in the model's config before exporting + + Returns: + Dictionary with the keys (and their corresponding values) to override + """ + if hasattr(self._config, "use_cache"): + return {"use_cache": False} + + return None + + @property + def default_batch_size(self) -> int: + """ + The default batch size to use if no other indication + + Returns: + Integer > 0 + """ + # Using 2 avoid ONNX making assumption about single sample batch + return OnnxConfig.DEFAULT_FIXED_BATCH + + @property + def default_sequence_length(self) -> int: + """ + The default sequence length to use if no other indication + + Returns: + Integer > 0 + """ + return OnnxConfig.DEFAULT_FIXED_SEQUENCE + + @property + def default_onnx_opset(self) -> int: + """ + Which onnx opset to use when exporting the model + + Returns: + Integer ONNX Opset version + """ + return DEFAULT_ONNX_OPSET + + @staticmethod + def use_external_data_format(num_parameters: int) -> bool: + """ + Flag indicating if the model requires using external data format + + Args: + num_parameters: Number of parameter on the model + + Returns: + True if model.num_parameters() * size_of(float32) >= 2Gb False otherwise + """ + + return ( + compute_serialized_parameters_size(num_parameters, ParameterFormat.Float) + >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT + ) + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + """ + Generate inputs to provide to the ONNX exporter for the specific framework + + Args: + tokenizer: The tokenizer associated with this model configuration + batch_size: The batch size (int) to export the model for (-1 means dynamic axis) + seq_length: The sequence length (int) to export the model for (-1 means dynamic axis) + is_pair: Indicate if the input is a pair (sentence 1, sentence 2) + framework: The framework (optional) the tokenizer will generate tensor for + + Returns: + Mapping[str, Tensor] holding the kwargs to provide to the model's forward function + """ + + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + return dict(tokenizer(dummy_input, return_tensors=framework)) + + +class OnnxConfigWithPast(OnnxConfig, ABC): + def __init__(self, config: PretrainedConfig, use_past: bool = False): + super().__init__(config) + self.use_past = use_past + + @classmethod + def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast": + """ + Instantiate a OnnxConfig with `use_past` attribute set to True + + Args: + config: The underlying model's config to use when exporting to ONNX + + Returns: + OnnxConfig with `.use_past = True` + """ + return cls(config, use_past=True) + + @property + def values_override(self) -> Optional[Mapping[str, Any]]: + if hasattr(self._config, "use_cache"): + return {"use_cache": self.use_past} + + return None + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + + # When use_past the caching mechanism requires inputs to be only 1 single token + fixed_sequence_length = 1 if self.use_past else self.default_sequence_length + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework))) diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py new file mode 100644 index 0000000000..e844392feb --- /dev/null +++ b/src/transformers/onnx/convert.py @@ -0,0 +1,225 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from inspect import signature +from itertools import chain +from pathlib import Path +from typing import Iterable, List, Tuple, Union + +import numpy as np +from packaging.version import Version, parse + +from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available +from ..utils import logging +from .config import OnnxConfig +from .utils import flatten_output_collection_property + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# This is the minimal required version to support some ONNX Runtime features +ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0") + + +def check_onnxruntime_requirements(minimum_version: Version): + """ + Check onnxruntime is installed and if the installed version match is recent enough + + Raises: + ImportError: If onnxruntime is not installed or too old version is found + """ + try: + import onnxruntime + + # Parse the version of the installed onnxruntime + ort_version = parse(onnxruntime.__version__) + + # We require 1.4.0 minimum + if ort_version < ORT_QUANTIZE_MINIMUM_VERSION: + raise ImportError( + f"We found an older version of onnxruntime ({onnxruntime.__version__}) " + f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n" + f"Please update onnxruntime by running `pip install --upgrade onnxruntime`" + ) + + except ImportError: + raise ImportError( + "onnxruntime doesn't seem to be currently installed. " + "Please install the onnxruntime by running `pip install onnxruntime`" + " and relaunch the conversion." + ) + + +def export( + tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path +) -> Tuple[List[str], List[str]]: + """ + Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR + + Args: + tokenizer: + model: + config: + opset: + output: + + Returns: + + """ + if not is_torch_available(): + raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") + + import torch + from torch.onnx import export + + logger.info(f"Using framework PyTorch: {torch.__version__}") + torch.set_grad_enabled(False) + model.config.return_dict = True + + # Check if we need to override certain configuration item + if config.values_override is not None: + logger.info(f"Overriding {len(config.values_override)} configuration item(s)") + for override_config_key, override_config_value in config.values_override.items(): + logger.info(f"\t- {override_config_key} -> {override_config_value}") + setattr(model.config, override_config_key, override_config_value) + + # Ensure inputs match + # TODO: Check when exporting QA we provide "is_pair=True" + model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH) + inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) + onnx_outputs = list(config.outputs.keys()) + + if not inputs_match: + raise ValueError("Model and config inputs doesn't match") + + # export can works with named args but the dict containing named args as to be last element of the args tuple + export( + model, + (model_inputs,), + f=output.as_posix(), + input_names=list(config.inputs.keys()), + output_names=onnx_outputs, + dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())}, + do_constant_folding=True, + use_external_data_format=config.use_external_data_format(model.num_parameters()), + enable_onnx_checker=True, + opset_version=opset, + ) + + return matched_inputs, onnx_outputs + + +def validate_model_outputs( + config: OnnxConfig, + tokenizer: PreTrainedTokenizer, + reference_model: Union[PreTrainedModel, TFPreTrainedModel], + onnx_model: Path, + onnx_named_outputs: List[str], + atol: float, +): + from onnxruntime import InferenceSession, SessionOptions + + logger.info("Validating ONNX model...") + + reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH) + + # Create ONNX Runtime session + options = SessionOptions() + session = InferenceSession(onnx_model.as_posix(), options) + + # Compute outputs from the reference model + ref_outputs = reference_model(**reference_model_inputs) + ref_outputs_dict = {} + + # We flatten potential collection of outputs (i.e. past_keys) to a flat structure + for name, value in ref_outputs.items(): + if isinstance(value, (list, tuple)): + value = flatten_output_collection_property(name, value) + ref_outputs_dict.update(value) + else: + ref_outputs_dict[name] = value + + # We flatten potential collection of inputs (i.e. past_keys) + onnx_inputs = {} + for name, value in reference_model_inputs.items(): + if isinstance(value, (list, tuple)): + value = flatten_output_collection_property(name, value) + onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()}) + else: + onnx_inputs[name] = value.numpy() + + # Compute outputs from the ONNX model + onnx_outputs = session.run(onnx_named_outputs, onnx_inputs) + + # Check we have a subset of the keys into onnx_outputs against ref_outputs + ref_outputs_set, onnx_outputs_set = set(ref_outputs_dict.keys()), set(onnx_named_outputs) + if not onnx_outputs_set.issubset(ref_outputs_set): + logger.info( + f"\t-[x] ONNX model outputs' name {onnx_outputs_set} doesn't match reference model {ref_outputs_set}" + ) + + raise ValueError( + "Outputs doesn't match between reference model and ONNX exported model: " + f"{onnx_outputs_set.difference(ref_outputs_set)}" + ) + else: + logger.info(f"\t-[✓] ONNX model outputs' name match reference model ({onnx_outputs_set}") + + # Check the shape and values match + for name, ort_value in zip(onnx_named_outputs, onnx_outputs): + ref_value = ref_outputs_dict[name].numpy() + logger.info(f'\t- Validating ONNX Model output "{name}":') + + # Shape + if not ort_value.shape == ref_value.shape: + logger.info(f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}") + raise ValueError( + "Outputs shape doesn't match between reference model and ONNX exported model: " + f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)" + ) + else: + logger.info(f"\t\t-[✓] {ort_value.shape} matchs {ref_value.shape}") + + # Values + if not np.allclose(ref_value, ort_value, atol=atol): + logger.info(f"\t\t-[x] values not close enough (atol: {atol})") + raise ValueError( + "Outputs values doesn't match between reference model and ONNX exported model: " + f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}" + ) + else: + logger.info(f"\t\t-[✓] all values close (atol: {atol})") + + +def ensure_model_and_config_inputs_match( + model: Union[PreTrainedModel, TFPreTrainedModel], model_inputs: Iterable[str] +) -> Tuple[bool, List[str]]: + """ + + :param model_inputs: + :param config_inputs: + :return: + """ + forward_parameters = signature(model.forward).parameters + model_inputs_set = set(model_inputs) + + # We are fine if config_inputs has more keys than model_inputs + forward_inputs_set = set(forward_parameters.keys()) + is_ok = model_inputs_set.issubset(forward_inputs_set) + + # Make sure the input order match (VERY IMPORTANT !!!!) + matching_inputs = forward_inputs_set.intersection(model_inputs_set) + ordered_inputs = [parameter for parameter in forward_parameters.keys() if parameter in matching_inputs] + return is_ok, ordered_inputs diff --git a/src/transformers/onnx/utils.py b/src/transformers/onnx/utils.py new file mode 100644 index 0000000000..b32c99119d --- /dev/null +++ b/src/transformers/onnx/utils.py @@ -0,0 +1,82 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ctypes import c_float, sizeof +from enum import Enum +from typing import Any, Dict, Iterable + + +class ParameterFormat(Enum): + Float = c_float + + @property + def size(self) -> int: + """ + Number of byte required for this data type + + Returns: + Integer > 0 + """ + return sizeof(self.value) + + +def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int: + """ + + Args: + dimension: + fixed_dimension: + num_token_to_add: + + Returns: + + """ + # < 0 is possible if using a dynamic axis + if dimension <= 0: + dimension = fixed_dimension + + dimension -= num_token_to_add + return dimension + + +def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int: + """ + Compute the size taken by all the parameters in the given the storage format when serializing the model + + Args: + num_parameters: Number of parameters to be saved + dtype: The data format each parameter will be saved + + Returns: + Size (in byte) taken to save all the parameters + """ + return num_parameters * dtype.size + + +def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]: + """ + Flatten any potential nested structure expanding the name of the field with the index of the element within the + structure. + + Args: + name: The name of the nested structure + field: The structure to, potentially, be flattened + + Returns: + (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure. + + """ + from itertools import chain + + return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))} diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 439cee385d..8fa904a4e6 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -33,6 +33,7 @@ from .file_utils import ( is_datasets_available, is_faiss_available, is_flax_available, + is_keras2onnx_available, is_onnx_available, is_pandas_available, is_rjieba_available, @@ -234,6 +235,13 @@ def require_rjieba(test_case): return test_case +def require_keras2onnx(test_case): + if not is_keras2onnx_available(): + return unittest.skip("test requires keras2onnx")(test_case) + else: + return test_case + + def require_onnx(test_case): if not is_onnx_available(): return unittest.skip("test requires ONNX")(test_case) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 3e7734197e..3c907c7470 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -35,7 +35,7 @@ from transformers.testing_utils import ( _tf_gpu_memory_limit, is_pt_tf_cross_test, is_staging_test, - require_onnx, + require_keras2onnx, require_tf, slow, tooslow, @@ -325,7 +325,7 @@ class TFModelTesterMixin: self.assertEqual(len(incompatible_ops), 0, incompatible_ops) - @require_onnx + @require_keras2onnx @slow def test_onnx_runtime_optimize(self): if not self.test_onnx: diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py new file mode 100644 index 0000000000..a4480c5746 --- /dev/null +++ b/tests/test_onnx_v2.py @@ -0,0 +1,251 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile +from unittest import TestCase +from unittest.mock import patch + +from transformers import ( # LongformerConfig, + AlbertConfig, + AutoTokenizer, + BartConfig, + DistilBertConfig, + GPT2Config, + RobertaConfig, + T5Config, + XLMRobertaConfig, + is_torch_available, +) +from transformers.models.albert import AlbertOnnxConfig +from transformers.models.bart import BartOnnxConfig +from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig +from transformers.models.distilbert import DistilBertOnnxConfig + +# from transformers.models.longformer import LongformerOnnxConfig +from transformers.models.gpt2 import GPT2OnnxConfig +from transformers.models.roberta import RobertaOnnxConfig +from transformers.models.t5 import T5OnnxConfig +from transformers.models.xlm_roberta import XLMRobertaOnnxConfig +from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs +from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast +from transformers.onnx.utils import ( + compute_effective_axis_dimension, + compute_serialized_parameters_size, + flatten_output_collection_property, +) +from transformers.testing_utils import require_onnx, require_torch, slow + + +@require_onnx +class OnnxUtilsTestCaseV2(TestCase): + """ + Cover all the utilities involved to export ONNX models + """ + + def test_compute_effective_axis_dimension(self): + """ + When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1. + We cannot generate an effective tensor with axis dim == -1, so we trick by using some "fixed" values + (> 1 to avoid ONNX squeezing the axis). + + This test ensure we are correctly replacing generated batch / sequence tensor with axis > 1 + """ + + # Dynamic axis (batch, no token added by the tokenizer) + self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0), 2) + + # Static axis (batch, no token added by the tokenizer) + self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=2, num_token_to_add=0), 2) + + # Dynamic axis (sequence, token added by the tokenizer 2 (no pair)) + self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6) + self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6) + + # Dynamic axis (sequence, token added by the tokenizer 3 (pair)) + self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5) + self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5) + + def test_compute_parameters_serialized_size(self): + """ + This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the + parameters for the specified parameter's dtype. + """ + self.assertEqual(compute_serialized_parameters_size(2, ParameterFormat.Float), 2 * ParameterFormat.Float.size) + + def test_flatten_output_collection_property(self): + """ + This test ensures we correctly flatten nested collection such as the one we use when returning past_keys. + past_keys = Tuple[Tuple] + + ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n} + """ + self.assertEqual( + flatten_output_collection_property("past_key", [[0], [1], [2]]), + { + "past_key.0": 0, + "past_key.1": 1, + "past_key.2": 2, + }, + ) + + +class OnnxConfigTestCaseV2(TestCase): + """ + Cover the test for models default. + + Default means no specific features is being enabled on the model. + """ + + @patch.multiple(OnnxConfig, __abstractmethods__=set()) + def test_use_external_data_format(self): + """ + External data format is required only if the serialized size of the parameters if bigger than 2Gb + """ + TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT + + # No parameters + self.assertFalse(OnnxConfig.use_external_data_format(0)) + + # Some parameters + self.assertFalse(OnnxConfig.use_external_data_format(1)) + + # Almost 2Gb parameters + self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // ParameterFormat.Float.size)) + + # Exactly 2Gb parameters + self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT)) + + # More than 2Gb parameters + self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size)) + + +class OnnxConfigWithPastTestCaseV2(TestCase): + """ + Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX) + """ + + SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)} + + @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) + def test_use_past(self): + """ + Ensure the use_past variable is correctly being set + """ + for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS: + with self.subTest(name): + self.assertFalse( + OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past" + ) + + self.assertTrue( + OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past" + ) + + @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) + def test_values_override(self): + """ + Ensure the use_past variable correctly set the `use_cache` value in model's configuration + """ + for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS: + with self.subTest(name): + + # without past + onnx_config_default = OnnxConfigWithPast.default(config()) + self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None") + self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present") + self.assertFalse( + onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past" + ) + + # with past + onnx_config_default = OnnxConfigWithPast.with_past(config()) + self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None") + self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present") + self.assertTrue( + onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past" + ) + + +if is_torch_available(): + from transformers import ( + AlbertModel, + BartModel, + BertModel, + DistilBertModel, + GPT2Model, + RobertaModel, + T5Model, + XLMRobertaModel, + ) + + PYTORCH_EXPORT_DEFAULT_MODELS = { + ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig), + ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig), + ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig), + ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig), + ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), + # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), + ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), + ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), + ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), + } + + PYTORCH_EXPORT_WITH_PAST_MODELS = { + # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig), + # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig), + # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig) + } + + +class OnnxExportTestCaseV2(TestCase): + """ + Integration tests ensuring supported models are correctly exported + """ + + @slow + @require_torch + def test_pytorch_export_default(self): + from transformers.onnx import export + + for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS: + with self.subTest(name): + self.assertTrue(hasattr(onnx_config_class, "default")) + + tokenizer = AutoTokenizer.from_pretrained(model) + model = model_class(config_class()) + onnx_config = onnx_config_class.default(model.config) + + with NamedTemporaryFile("w") as output: + onnx_inputs, onnx_outputs = export( + tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name) + ) + + try: + validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5) + except ValueError as ve: + self.fail(f"{name} -> {ve}") + + @slow + @require_torch + def test_pytorch_export_with_past(self): + from transformers.onnx import export + + for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS: + with self.subTest(name): + self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()") + + tokenizer = AutoTokenizer.from_pretrained(model) + model = model_class(config_class()) + onnx_config = onnx_config_class.with_past(model.config) + + self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.") + self.assertTrue( + onnx_config.use_past, "OnnxConfigWithPast.use_past should be if called with with_past()" + ) + + with NamedTemporaryFile("w") as output: + output = Path(output.name) + onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output) + + try: + validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5) + except ValueError as ve: + self.fail(f"{name} -> {ve}")