[RFC] Laying down building stone for more flexible ONNX export capabilities (#11786)
* Laying down building stone for more flexible ONNX export capabilities
* Ability to provide a map of config key to override before exporting.
* Makes it possible to export BART with/without past keys.
* Supports simple mathematical syntax for OnnxVariable.repeated
* Effectively apply value override from onnx config for model
* Supports export with additional features such as with-past for seq2seq
* Store the output path directly in the args for uniform usage across.
* Make BART_ONNX_CONFIG_* constants and fix imports.
* Support BERT model.
* Use tokenizer for more flexibility in defining the inputs of a model.
* Add TODO as reminder to provide the batch/sequence_length as CLI args
* Enable optimizations to be done on the model.
* Enable GPT2 + past
* Improve model validation with outputs containing nested structures
* Enable Roberta
* Enable Albert
* Albert requires opset >= 12
* BERT-like models requires opset >= 12
* Remove double printing.
* Enable XLM-Roberta
* Enable DistilBERT
* Disable optimization by default
* Fix missing setattr when applying optimizer_features
* Add value field to OnnxVariable to define constant input (not from tokenizers)
* Add T5 support.
* Simplify model type retrieval
* Example exporting token_classification pipeline for DistilBERT.
* Refactoring to package `transformers.onnx`
* Solve circular dependency & __main__
* Remove unnecessary imports in `__init__`
* Licences
* Use @Narsil's suggestion to forward the model's configuration to the ONNXConfig to avoid interpolation.
* Onnx export v2 fixes (#12388)
* Tiny fixes
  Remove `convert_pytorch` from onnxruntime-less runtimes
  Correct reference to model
* Style
* Fix Copied from
* LongFormer ONNX config.
* Removed optimizations
* Remove bad merge relicas.
* Remove unused constants.
* Remove some deleted constants from imports.
* Fix unittest to remove usage of PyTorch model for onnx.utils.
* Fix distilbert export
* Enable ONNX export test for supported model.
* Style.
* Fix lint.
* Enable all supported default models.
* GPT2 only has one output
* Fix bad property name when overriding config.
* Added unittests and docstrings.
* Disable with_past tests for now.
* Enable outputs validation for default export.
* Remove graph opt lvls.
* Last commit with on-going past commented.
* Style.
* Disabled `with_past` for now
* Remove unused imports.
* Remove framework argument
* Remove TFPreTrainedModel reference
* Add documentation
* Add onnxruntime tests to CircleCI
* Add test
* Rename `convert_pytorch` to `export`
* Use OrderedDict for dummy inputs
* WIP Wav2Vec2
* Revert "WIP Wav2Vec2"
  This reverts commit f665efb04c92525c3530e589029f0ae7afdf603e.
* Style
* Use OrderedDict for I/O
* Style.
* Specify OrderedDict documentation.
* Style :)

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
@@ -345,6 +345,32 @@ jobs:
                       - '~/.cache/pip'
             - run: python -m pytest -sv ./tests/ -m is_staging_test

+    run_tests_onnxruntime:
+        working_directory: ~/transformers
+        docker:
+            - image: circleci/python:3.7
+        environment:
+            OMP_NUM_THREADS: 1
+            TRANSFORMERS_IS_CI: yes
+        resource_class: xlarge
+        parallelism: 1
+        steps:
+            - checkout
+            - restore_cache:
+                  keys:
+                      - v0.4-torch-{{ checksum "setup.py" }}
+                      - v0.4-{{ checksum "setup.py" }}
+            - run: pip install --upgrade pip
+            - run: pip install .[torch,testing,sentencepiece,onnxruntime]
+            - save_cache:
+                  key: v0.4-onnx-{{ checksum "setup.py" }}
+                  paths:
+                      - '~/.cache/pip'
+            - run: python -m pytest -n 1 --dist=loadfile -s --make-reports=tests_torch ./tests/* -k onnx | tee tests_output.txt
+            - store_artifacts:
+                  path: ~/transformers/tests_output.txt
+            - store_artifacts:
+                  path: ~/transformers/reports
     build_doc:
         working_directory: ~/transformers
         docker:
@@ -485,6 +511,7 @@ workflows:
             - run_tests_flax
             - run_tests_pipelines_torch
             - run_tests_pipelines_tf
+            - run_tests_onnxruntime
             - run_tests_hub
             - build_doc
             - deploy_doc: *workflow_filters

@@ -21,11 +21,137 @@ Projects `ONNX (Open Neural Network eXchange) <http://onnx.ai>`_ and `ONNXRuntim
 unified and community-driven format to store and, by extension, efficiently execute neural network leveraging a variety
 of hardware and dedicated optimizations.

+
 Starting from transformers v2.10.0 we partnered with ONNX Runtime to provide an easy export of transformers models to
 the ONNX format. You can have a look at the effort by looking at our joint blog post `Accelerate your NLP pipelines
 using Hugging Face Transformers and ONNX Runtime
 <https://medium.com/microsoftazure/accelerate-your-nlp-pipelines-using-hugging-face-transformers-and-onnx-runtime-2443578f4333>`_.

+
+Configuration-based approach
+-----------------------------------------------------------------------------------------------------------------------
+
+Transformers v4.9.0 introduces a new package: ``transformers.onnx``. This package allows converting checkpoints to an
+ONNX graph by leveraging configuration objects. These configuration objects come ready-made for a number of model
+architectures, and are designed to be easily extended to other architectures.
+
+Ready-made configurations include the following models:
+
+- ALBERT
+- BART
+- BERT
+- DistilBERT
+- GPT-2
+- RoBERTa
+- T5
+- XLM-RoBERTa
+
+This conversion relies on the PyTorch version of the models and therefore requires PyTorch to be installed. If you
+would like to be able to convert from TensorFlow, please let us know by opening an issue.
+
+.. note::
+    The models showcased here are close to feature complete, but still lack some features that are currently in
+    development. Namely, the ability to handle the past key values for decoder models is currently in the works.
+
+
+Converting a model to ONNX using the ``transformers.onnx`` package
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The package may be used as a Python module:
+
+.. code-block::
+
+    python -m transformers.onnx --help
+
+    usage: Hugging Face ONNX Exporter tool [-h] -m MODEL -f {pytorch} [--features {default}] [--opset OPSET] [--atol ATOL] output
+
+    positional arguments:
+      output                Path indicating where to store generated ONNX model.
+
+    optional arguments:
+      -h, --help            show this help message and exit
+      -m MODEL, --model MODEL
+                            Model's name of path on disk to load.
+      -f {pytorch}, --framework {pytorch}
+                            Framework to use when exporting. Possible values are: {'pytorch'}
+      --features {default}  Export the model with some additional features.
+      --opset OPSET         ONNX opset version to export the model with (default 12).
+      --atol ATOL           Absolute difference tolerance when validating the model.
+
+Exporting a checkpoint using a ready-made configuration can be done as follows:
+
+.. code-block::
+
+    python -m transformers.onnx -f pytorch --model=bert-base-cased onnx/bert-base-cased/
+
+This exports an ONNX graph of the mentioned checkpoint. Here it is ``bert-base-cased``, but it can be any model from
+the hub, or a local path.
+
+It will be exported under ``onnx/bert-base-cased``. You should see logs similar to the following:
+
+.. code-block::
+
+    Validating ONNX model...
+        -[✓] ONNX model outputs' name match reference model ({'pooler_output', 'last_hidden_state'}
+        - Validating ONNX Model output "last_hidden_state":
+            -[✓] (2, 8, 768) matchs (2, 8, 768)
+            -[✓] all values close (atol: 0.0001)
+        - Validating ONNX Model output "pooler_output":
+            -[✓] (2, 768) matchs (2, 768)
+            -[✓] all values close (atol: 0.0001)
+    All good, model saved at: onnx/bert-base-cased/model.onnx
+
+
+Implementing a custom configuration for an unsupported architecture
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Let's take a look at the changes necessary to add a custom configuration for an unsupported architecture. Firstly, we
+will need a custom ONNX configuration object that details the model inputs and outputs. The BERT ONNX configuration is
+shown below:
+
+.. code-block::
+
+    class BertOnnxConfig(OnnxConfig):
+        @property
+        def inputs(self) -> Mapping[str, Mapping[int, str]]:
+            return OrderedDict(
+                [
+                    ("input_ids", {0: "batch", 1: "sequence"}),
+                    ("attention_mask", {0: "batch", 1: "sequence"}),
+                    ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ]
+            )
+
+        @property
+        def outputs(self) -> Mapping[str, Mapping[int, str]]:
+            return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
+
+Let's understand what's happening here. This configuration has two properties: ``inputs`` and ``outputs``.
+
+The ``inputs`` property returns a dictionary where each key corresponds to an expected input, and each value maps the
+dynamic axes of that input to their names.
+
+For BERT, three inputs are necessary. They all share the same shape, made up of two dimensions: the first is the batch
+and the second is the sequence.
+
+The ``outputs`` property returns a similar dictionary where each key corresponds to an expected output, and each value
+maps the dynamic axes of that output to their names.
+
+Once this is done, a single step remains: adding this configuration object to the initialisation of the model class,
+and to the general ``transformers`` initialisation.
+
+An important fact to notice is the use of ``OrderedDict`` in both the ``inputs`` and ``outputs`` properties. This is a
+requirement, as inputs are matched against their relative position within the ``PreTrainedModel.forward()`` prototype
+and outputs are matched against their position in the returned ``BaseModelOutputX`` instance.
+
+
+Graph conversion
+-----------------------------------------------------------------------------------------------------------------------
+
+.. note::
+    The approach detailed here is being deprecated. We recommend you follow the part above for an up-to-date approach.
+
+
 Exporting a model is done through the script `convert_graph_to_onnx.py` at the root of the transformers sources. The
 following command shows how easy it is to export a BERT model from the library, simply run:


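As a reviewer-style illustration of the ordering requirement described in the documentation hunk above, here is a minimal, self-contained sketch. `ToyOnnxConfig` and `export_arguments` are hypothetical stand-ins written for this note and are not part of `transformers.onnx`. The sketch only shows how ordered mappings of the kind returned by `BertOnnxConfig` could be flattened into the `input_names`, `output_names` and `dynamic_axes` arguments that `torch.onnx.export` accepts; how the package actually does this is not shown in this diff.

```python
from collections import OrderedDict
from typing import Dict, List, Mapping, Tuple


class ToyOnnxConfig:
    """Hypothetical stand-in for an ONNX configuration (not the transformers class)."""

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # The order mirrors the positional order of the arguments of forward().
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        # The order mirrors the position of each field in the returned model output.
        return OrderedDict(
            [
                ("last_hidden_state", {0: "batch", 1: "sequence"}),
                ("pooler_output", {0: "batch"}),
            ]
        )


def export_arguments(config: ToyOnnxConfig) -> Tuple[List[str], List[str], Dict[str, Dict[int, str]]]:
    """Build ordered input names, ordered output names and a merged dynamic-axes mapping."""
    input_names = list(config.inputs.keys())
    output_names = list(config.outputs.keys())
    dynamic_axes = {name: dict(axes) for name, axes in config.inputs.items()}
    dynamic_axes.update({name: dict(axes) for name, axes in config.outputs.items()})
    return input_names, output_names, dynamic_axes


input_names, output_names, dynamic_axes = export_arguments(ToyOnnxConfig())
print(input_names)
print(output_names)
print(dynamic_axes["pooler_output"])
```

Because the names are taken from ordered mappings, swapping two entries in `inputs` would change which positional argument each name is bound to, which is why a plain unordered mapping is not enough here.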
@@ -148,9 +148,30 @@ except importlib_metadata.PackageNotFoundError:
     _faiss_available = False


-_onnx_available = (
-    importlib.util.find_spec("keras2onnx") is not None and importlib.util.find_spec("onnxruntime") is not None
-)
+coloredlogs = importlib.util.find_spec("coloredlogs") is not None
+try:
+    _coloredlogs_available = importlib_metadata.version("coloredlogs")
+    logger.debug(f"Successfully imported coloredlogs version {_coloredlogs_available}")
+except importlib_metadata.PackageNotFoundError:
+    _coloredlogs_available = False
+
+
+sympy_available = importlib.util.find_spec("sympy") is not None
+try:
+    _sympy_available = importlib_metadata.version("sympy")
+    logger.debug(f"Successfully imported sympy version {_sympy_available}")
+except importlib_metadata.PackageNotFoundError:
+    _sympy_available = False
+
+
+_keras2onnx_available = importlib.util.find_spec("keras2onnx") is not None
+try:
+    _keras2onnx_version = importlib_metadata.version("keras2onnx")
+    logger.debug(f"Successfully imported keras2onnx version {_keras2onnx_version}")
+except importlib_metadata.PackageNotFoundError:
+    _keras2onnx_available = False
+
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
 try:
     _onxx_version = importlib_metadata.version("onnx")
     logger.debug(f"Successfully imported onnx version {_onxx_version}")
@@ -292,6 +313,14 @@ def is_tf_available():
     return _tf_available


+def is_coloredlogs_available():
+    return _coloredlogs_available
+
+
+def is_keras2onnx_available():
+    return _keras2onnx_available
+
+
 def is_onnx_available():
     return _onnx_available


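The availability flags in the two hunks above follow a probe-then-read-version pattern. A compact sketch of that pattern is given below, with two assumptions stated plainly: it uses the standard-library `importlib.metadata` (Python 3.8+) instead of the `importlib_metadata` alias seen in the diff, and `probe_package` is a hypothetical helper name, since the real module keeps module-level flags rather than a function.

```python
import importlib.util
from importlib import metadata
from typing import Optional, Tuple


def probe_package(name: str) -> Tuple[bool, Optional[str]]:
    """Return (available, version) for a package, without importing it.

    Hypothetical helper for illustration; not part of transformers.
    """
    try:
        spec_found = importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for dotted names or modules with a broken __spec__.
        spec_found = False
    if not spec_found:
        return False, None
    try:
        return True, metadata.version(name)
    except metadata.PackageNotFoundError:
        # Importable, but no installed distribution is registered under this name.
        return False, None


available, version = probe_package("onnxruntime")
print(available, version)
```

Returning the version together with the flag keeps the two values from drifting apart, which is the kind of mix-up the original `coloredlogs` debug message ran into.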
@@ -28,7 +28,7 @@ from ...file_utils import (


 _import_structure = {
-    "configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
+    "configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig", "AlbertOnnxConfig"],
 }

 if is_sentencepiece_available():
@@ -67,7 +67,7 @@ if is_tf_available():


 if TYPE_CHECKING:
-    from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
+    from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig

     if is_sentencepiece_available():
         from .tokenization_albert import AlbertTokenizer

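Each new `*OnnxConfig` name is registered twice in these `__init__` hunks: once in `_import_structure` and once under `TYPE_CHECKING`. The sketch below illustrates, under stated assumptions, why a name-to-module mapping of that kind allows imports to be deferred. `LazyNamespace` is a hypothetical, much simplified analogue written for this note; it is not the `_LazyModule` class used by the library, and it resolves standard-library modules only so that it runs anywhere.

```python
import importlib
from typing import Any, Dict, List


class LazyNamespace:
    """Hypothetical, simplified analogue of a lazy module: resolves names on first access."""

    def __init__(self, import_structure: Dict[str, List[str]]):
        # Reverse the mapping: exported name -> module that defines it.
        self._name_to_module = {
            name: module for module, names in import_structure.items() for name in names
        }
        self.loaded: List[str] = []

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, i.e. on first access.
        module_name = self.__dict__.get("_name_to_module", {}).get(name)
        if module_name is None:
            raise AttributeError(name)
        module = importlib.import_module(module_name)
        self.loaded.append(module_name)
        value = getattr(module, name)
        setattr(self, name, value)  # cache so later lookups bypass __getattr__
        return value


lazy = LazyNamespace({"json": ["dumps"], "math": ["sqrt"]})
print(lazy.loaded)        # nothing has been resolved yet
print(lazy.sqrt(9.0))     # resolves "math" on first access
print(lazy.loaded)
```

In this sketch a name missing from the mapping raises `AttributeError`; in the real package layout the `TYPE_CHECKING` branch additionally gives static type checkers ordinary imports to inspect.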
@@ -14,8 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ ALBERT model configuration """
+from collections import OrderedDict
+from typing import Mapping

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig


 ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -151,3 +154,20 @@ class AlbertConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.classifier_dropout_prob = classifier_dropout_prob
         self.position_embedding_type = position_embedding_type
+
+
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->Albert
+class AlbertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("token_type_ids", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])

@@ -21,7 +21,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to


 _import_structure = {
-    "configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig"],
+    "configuration_bart": ["BART_PRETRAINED_CONFIG_ARCHIVE_MAP", "BartConfig", "BartOnnxConfig"],
     "tokenization_bart": ["BartTokenizer"],
 }

@@ -53,7 +53,7 @@ if is_flax_available():
     ]

 if TYPE_CHECKING:
-    from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
+    from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig, BartOnnxConfig
     from .tokenization_bart import BartTokenizer

     if is_tokenizers_available():

@@ -14,8 +14,11 @@
 # limitations under the License.
 """ BART model configuration """
 import warnings
+from collections import OrderedDict
+from typing import Mapping

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast
 from ...utils import logging


@@ -186,3 +189,32 @@ class BartConfig(PretrainedConfig):
     @property
     def hidden_size(self) -> int:
         return self.d_model
+
+
+class BartOnnxConfig(OnnxConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("past_keys", {0: "batch", 2: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )
+        else:
+            return OrderedDict(
+                [
+                    ("last_hidden_state", {0: "batch", 1: "sequence"}),
+                    ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}),
+                ]
+            )

@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to


 _import_structure = {
-    "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig"],
+    "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"],
     "tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
 }

@@ -77,7 +77,7 @@ if is_flax_available():
     ]

 if TYPE_CHECKING:
-    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
+    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
     from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer

     if is_tokenizers_available():

@@ -14,8 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BERT model configuration """
+from collections import OrderedDict
+from typing import Mapping

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging


@@ -154,3 +157,19 @@ class BertConfig(PretrainedConfig):
         self.gradient_checkpointing = gradient_checkpointing
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
+
+
+class BertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("token_type_ids", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])

@@ -22,7 +22,11 @@ from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available,


 _import_structure = {
-    "configuration_distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig"],
+    "configuration_distilbert": [
+        "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "DistilBertConfig",
+        "DistilBertOnnxConfig",
+    ],
     "tokenization_distilbert": ["DistilBertTokenizer"],
 }

@@ -56,7 +60,11 @@ if is_tf_available():


 if TYPE_CHECKING:
-    from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
+    from .configuration_distilbert import (
+        DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        DistilBertConfig,
+        DistilBertOnnxConfig,
+    )
     from .tokenization_distilbert import DistilBertTokenizer

     if is_tokenizers_available():

@@ -13,8 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ DistilBERT model configuration """
+from collections import OrderedDict
+from typing import Mapping

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging


@@ -135,3 +138,18 @@ class DistilBertConfig(PretrainedConfig):
     @property
     def num_hidden_layers(self):
         return self.n_layers
+
+
+class DistilBertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])

@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to


 _import_structure = {
-    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config"],
+    "configuration_gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2OnnxConfig"],
     "tokenization_gpt2": ["GPT2Tokenizer"],
 }

@@ -55,7 +55,7 @@ if is_flax_available():
     _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]

 if TYPE_CHECKING:
-    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
+    from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2OnnxConfig
     from .tokenization_gpt2 import GPT2Tokenizer

     if is_tokenizers_available():

@@ -14,8 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ OpenAI GPT-2 configuration """
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from transformers import PreTrainedTokenizer, TensorType, is_torch_available

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfigWithPast
 from ...utils import logging


@@ -195,3 +200,61 @@ class GPT2Config(PretrainedConfig):
     @property
     def num_hidden_layers(self):
         return self.n_layer
+
+
+class GPT2OnnxConfig(OnnxConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict({"input_ids": {0: "batch"}})
+        if self.use_past:
+            for i in range(self._config.n_layer * 2):
+                common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}
+
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+        else:
+            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
+
+        return common_inputs
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_outputs = OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}})
+        if self.use_past:
+            for i in range(self._config.n_layer * 2):
+                common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}
+
+            return common_outputs
+
+        return common_outputs
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+
+        # We need to order the inputs in the way they appear in forward()
+        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
+
+        # Need to add the past_keys
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                batch = common_inputs["input_ids"].shape[0]
+                ordered_inputs["past_key_values"] = [
+                    (
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                    )
+                    for _ in range(self._config.n_layer)
+                ]
+
+        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+        return ordered_inputs

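To make the `use_past` bookkeeping in `GPT2OnnxConfig` easier to follow, here is a small sketch that reproduces only the naming and shape arithmetic from the hunk above, without PyTorch. The helper names `past_axes` and `dummy_past_shapes` are hypothetical, and the numeric values are illustrative assumptions, not read from any checkpoint.

```python
from collections import OrderedDict
from typing import Dict, List, Mapping, Tuple

Shape = Tuple[int, int, int, int]


def past_axes(n_layer: int) -> Mapping[str, Dict[int, str]]:
    """Name one graph input per past tensor: a key and a value for every layer."""
    return OrderedDict(
        (f"past_key_values.{i}", {0: "batch", 2: "sequence"}) for i in range(n_layer * 2)
    )


def dummy_past_shapes(batch: int, n_layer: int, n_head: int, hidden_size: int) -> List[Tuple[Shape, Shape]]:
    """Shapes of the zero-filled (key, value) pairs, using a past sequence length of 1."""
    head_dim = hidden_size // n_head
    shape = (batch, n_head, 1, head_dim)
    return [(shape, shape) for _ in range(n_layer)]


# Illustrative values only (assumed for this sketch).
axes = past_axes(n_layer=2)
shapes = dummy_past_shapes(batch=2, n_layer=2, n_head=4, hidden_size=32)
print(list(axes))
print(shapes[0])
```

The flattened names count `n_layer * 2` entries because each layer contributes one key tensor and one value tensor, while the dummy inputs keep them grouped as one pair per layer.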
@@ -22,7 +22,11 @@ from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available,


 _import_structure = {
-    "configuration_longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig"],
+    "configuration_longformer": [
+        "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "LongformerConfig",
+        "LongformerOnnxConfig",
+    ],
     "tokenization_longformer": ["LongformerTokenizer"],
 }

@@ -57,7 +61,11 @@ if is_tf_available():


 if TYPE_CHECKING:
-    from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
+    from .configuration_longformer import (
+        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        LongformerConfig,
+        LongformerOnnxConfig,
+    )
     from .tokenization_longformer import LongformerTokenizer

     if is_tokenizers_available():

@@ -13,9 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Longformer configuration """
|
""" Longformer configuration """
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import List, Mapping, Union
|
||||||
|
|
||||||
from typing import List, Union
|
from ...onnx import OnnxConfig
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.configuration_roberta import RobertaConfig
|
from ..roberta.configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
@@ -69,3 +70,18 @@ class LongformerConfig(RobertaConfig):
|
|||||||
def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
|
def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
|
||||||
super().__init__(sep_token_id=sep_token_id, **kwargs)
|
super().__init__(sep_token_id=sep_token_id, **kwargs)
|
||||||
self.attention_window = attention_window
|
self.attention_window = attention_window
|
||||||
|
|
||||||
|
|
||||||
|
class LongformerOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_to
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig"],
|
"configuration_roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaOnnxConfig"],
|
||||||
"tokenization_roberta": ["RobertaTokenizer"],
|
"tokenization_roberta": ["RobertaTokenizer"],
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
|
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaOnnxConfig
|
||||||
from .tokenization_roberta import RobertaTokenizer
|
from .tokenization_roberta import RobertaTokenizer
|
||||||
|
|
||||||
if is_tokenizers_available():
|
if is_tokenizers_available():
|
||||||
|
|||||||
@@ -14,7 +14,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" RoBERTa configuration """
|
""" RoBERTa configuration """
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..bert.configuration_bert import BertConfig
|
from ..bert.configuration_bert import BertConfig
|
||||||
|
|
||||||
@@ -62,3 +65,18 @@ class RobertaConfig(BertConfig):
|
|||||||
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
|
def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
|
||||||
"""Constructs RobertaConfig."""
|
"""Constructs RobertaConfig."""
|
||||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RobertaOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
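The `inputs` and `outputs` properties both map a tensor name to `{axis index: symbolic axis name}`. A minimal sketch of how such mappings could be merged into a single dictionary of the form accepted by the `dynamic_axes` argument of `torch.onnx.export`; that the exporter merges them exactly this way is an assumption, since that part of the export code is not shown here.

```python
from collections import OrderedDict
from itertools import chain

# Same axis definitions as RobertaOnnxConfig.inputs / .outputs above
inputs = OrderedDict(
    [
        ("input_ids", {0: "batch", 1: "sequence"}),
        ("attention_mask", {0: "batch", 1: "sequence"}),
    ]
)
outputs = OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])

# One flat mapping: tensor name -> {axis index: symbolic name}
dynamic_axes = {name: axes for name, axes in chain(inputs.items(), outputs.items())}
print(dynamic_axes)
```

Axes that are not listed (for example the hidden dimension) stay fixed in the exported graph.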
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from ...file_utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"],
|
"configuration_t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config", "T5OnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
@@ -66,7 +66,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
|
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config, T5OnnxConfig
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .tokenization_t5 import T5Tokenizer
|
from .tokenization_t5 import T5Tokenizer
|
||||||
|
|||||||
@@ -13,8 +13,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" T5 model configuration """
|
""" T5 model configuration """
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfigWithPast
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -132,3 +137,66 @@ class T5Config(PretrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def num_hidden_layers(self):
|
def num_hidden_layers(self):
|
||||||
return self.num_layers
|
return self.num_layers
|
||||||
|
|
||||||
|
|
||||||
|
class T5OnnxConfig(OnnxConfigWithPast):
|
||||||
|
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
||||||
|
super().__init__(config, use_past)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
common_inputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
("decoder_input_ids", {0: "batch"}),
|
||||||
|
("decoder_attention_mask", {0: "batch"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
for i in range(self._config.num_layers):
|
||||||
|
common_inputs[f"past_key_values.{i}.decoder.0"] = {0: "batch", 2: "past_sequence"}
|
||||||
|
common_inputs[f"past_key_values.{i}.decoder.1"] = {0: "batch", 2: "past_sequence"}
|
||||||
|
common_inputs[f"past_key_values.{i}.encoder.0"] = {0: "batch", 2: "past_sequence"}
|
||||||
|
common_inputs[f"past_key_values.{i}.encoder.1"] = {0: "batch", 2: "past_sequence"}
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
common_outputs = OrderedDict(
|
||||||
|
[
|
||||||
|
("last_hidden_state", {0: "batch", 1: "decoder_sequence"}),
|
||||||
|
("encoder_last_hidden_state", {0: "batch", 1: "encoder_sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_past:
|
||||||
|
for i in range(self._config.num_layers):
|
||||||
|
common_outputs[f"past_key_values.{i}.decoder.0"] = {0: "batch", 2: "decoder_sequence"}
|
||||||
|
common_outputs[f"past_key_values.{i}.decoder.1"] = {0: "batch", 2: "decoder_sequence"}
|
||||||
|
common_outputs[f"past_key_values.{i}.encoder.0"] = {0: "batch", 2: "encoder_sequence"}
|
||||||
|
common_outputs[f"past_key_values.{i}.encoder.1"] = {0: "batch", 2: "encoder_sequence"}
|
||||||
|
|
||||||
|
return common_outputs
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
if self.use_past:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
# Generate encoder inputs
|
||||||
|
encoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
||||||
|
|
||||||
|
# Generate decoder inputs
|
||||||
|
decoder_inputs = super().generate_dummy_inputs(tokenizer, batch_size, 1, is_pair, framework)
|
||||||
|
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||||
|
|
||||||
|
return dict(**encoder_inputs, **decoder_inputs)
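The merge performed at the end of `generate_dummy_inputs` can be sketched in isolation. Plain nested lists stand in for tensors here, and `merge_encoder_decoder_inputs` is a hypothetical name used only for this example.

```python
from typing import Any, Dict


def merge_encoder_decoder_inputs(encoder_inputs: Dict[str, Any], decoder_inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Prefix every decoder entry with `decoder_` and merge it with the encoder entries."""
    prefixed = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
    return dict(**encoder_inputs, **prefixed)


# Lists stand in for the tensors a tokenizer would return (illustrative values)
encoder = {"input_ids": [[3, 3, 3, 1]], "attention_mask": [[1, 1, 1, 1]]}
decoder = {"input_ids": [[1]], "attention_mask": [[1]]}
merged = merge_encoder_decoder_inputs(encoder, decoder)
print(list(merged))
```

The prefix keeps the decoder entries from colliding with the encoder's `input_ids` and `attention_mask`, and the resulting names line up with the keys declared in `inputs`.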
|
||||||
|
|||||||
@@ -28,7 +28,11 @@ from ...file_utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"],
|
"configuration_xlm_roberta": [
|
||||||
|
"XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP",
|
||||||
|
"XLMRobertaConfig",
|
||||||
|
"XLMRobertaOnnxConfig",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
@@ -62,7 +66,11 @@ if is_tf_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
|
from .configuration_xlm_roberta import (
|
||||||
|
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
|
XLMRobertaConfig,
|
||||||
|
XLMRobertaOnnxConfig,
|
||||||
|
)
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
from .tokenization_xlm_roberta import XLMRobertaTokenizer
|
||||||
|
|||||||
@@ -14,7 +14,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" XLM-RoBERTa configuration """
|
""" XLM-RoBERTa configuration """
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.configuration_roberta import RobertaConfig
|
from ..roberta.configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
@@ -38,3 +41,19 @@ class XLMRobertaConfig(RobertaConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_type = "xlm-roberta"
|
model_type = "xlm-roberta"
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.roberta.configuration_roberta.RobertaOnnxConfig with Roberta->XLMRoberta
|
||||||
|
class XLMRobertaOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", {0: "batch", 1: "sequence"}),
|
||||||
|
("attention_mask", {0: "batch", 1: "sequence"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"}), ("pooler_output", {0: "batch"})])
|
||||||
|
|||||||
18
src/transformers/onnx/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
|
||||||
|
from .convert import export, validate_model_outputs
|
||||||
|
from .utils import ParameterFormat, compute_serialized_parameters_size
|
||||||
150
src/transformers/onnx/__main__.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
from transformers.models.albert import AlbertOnnxConfig
|
||||||
|
from transformers.models.auto import AutoTokenizer
|
||||||
|
from transformers.models.bart import BartOnnxConfig
|
||||||
|
from transformers.models.bert import BertOnnxConfig
|
||||||
|
from transformers.models.distilbert import DistilBertOnnxConfig
|
||||||
|
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||||
|
from transformers.models.longformer import LongformerOnnxConfig
|
||||||
|
from transformers.models.roberta import RobertaOnnxConfig
|
||||||
|
from transformers.models.t5 import T5OnnxConfig
|
||||||
|
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||||
|
|
||||||
|
from .. import is_torch_available
|
||||||
|
from ..utils import logging
|
||||||
|
from .convert import export, validate_model_outputs
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers import AutoModel, PreTrainedModel
|
||||||
|
|
||||||
|
FEATURES_TO_AUTOMODELS = {
|
||||||
|
"default": AutoModel,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Set of model topologies we support associated to the features supported by each topology and the factory
|
||||||
|
SUPPORTED_MODEL_KIND = {
|
||||||
|
"albert": {"default": AlbertOnnxConfig.default},
|
||||||
|
"bart": {"default": BartOnnxConfig.default},
|
||||||
|
"bert": {"default": BertOnnxConfig.default},
|
||||||
|
"distilbert": {"default": DistilBertOnnxConfig.default},
|
||||||
|
"gpt2": {"default": GPT2OnnxConfig.default},
|
||||||
|
"longformer": {"default": LongformerOnnxConfig.default},
|
||||||
|
"roberta": {"default": RobertaOnnxConfig.default},
|
||||||
|
"t5": {"default": T5OnnxConfig.default},
|
||||||
|
"xlm-roberta": {"default": XLMRobertaOnnxConfig.default},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_from_features(features: str, model: str):
|
||||||
|
"""
|
||||||
|
Attempt to retrieve a model from a model's name and the features to be enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: The features required
|
||||||
|
model: The name of the model to export
|
||||||
|
|
||||||
|
Returns: The pretrained model loaded through the auto class registered for the requested features
|
||||||
|
|
||||||
|
"""
|
||||||
|
if features not in FEATURES_TO_AUTOMODELS:
|
||||||
|
raise KeyError(f"Unknown feature: {features}. Possible values are {list(FEATURES_TO_AUTOMODELS.keys())}")
|
||||||
|
|
||||||
|
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
|
||||||
|
|
||||||
|
|
||||||
|
def check_supported_model_or_raise(model: PreTrainedModel, features: str = "default") -> Tuple[str, Callable]:
|
||||||
|
"""
|
||||||
|
Check whether or not the model has the requested features
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model to export
|
||||||
|
features: The name of the features to check for availability
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(str) The type of the model, (Callable) the factory building the OnnxConfig that holds the model export properties
|
||||||
|
|
||||||
|
"""
|
||||||
|
if model.config.model_type not in SUPPORTED_MODEL_KIND:
|
||||||
|
raise KeyError(
|
||||||
|
f"{model.config.model_type} ({getattr(model, 'name', '')}) is not supported yet. "
|
||||||
|
f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
|
||||||
|
f"If you want to support ({model.config.model_type}) please propose a PR or open up an issue."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Look for the features
|
||||||
|
model_features = SUPPORTED_MODEL_KIND[model.config.model_type]
|
||||||
|
if features not in model_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"{model.config.model_type} doesn't support features {features}. "
|
||||||
|
f"Supported values are: {list(model_features.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return model.config.model_type, SUPPORTED_MODEL_KIND[model.config.model_type][features]
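The two-level lookup (model type first, then feature) and the factory it returns can be sketched with stand-in objects. `FakeOnnxConfig`, `SUPPORTED` and `lookup_onnx_config_factory` are hypothetical names for this example only; the real registry is `SUPPORTED_MODEL_KIND` above.

```python
from types import SimpleNamespace
from typing import Callable, Dict, Tuple


class FakeOnnxConfig:
    """Stand-in for an OnnxConfig subclass (illustration only)."""

    def __init__(self, config):
        self.config = config

    @classmethod
    def default(cls, config):
        return cls(config)


# model type -> feature name -> factory building the ONNX config
SUPPORTED: Dict[str, Dict[str, Callable]] = {"bert": {"default": FakeOnnxConfig.default}}


def lookup_onnx_config_factory(model_type: str, features: str = "default") -> Tuple[str, Callable]:
    if model_type not in SUPPORTED:
        raise KeyError(f"{model_type} is not supported yet. Only {list(SUPPORTED)} are supported.")
    model_features = SUPPORTED[model_type]
    if features not in model_features:
        raise ValueError(f"{model_type} doesn't support features {features}.")
    return model_type, model_features[features]


kind, factory = lookup_onnx_config_factory("bert")
onnx_config = factory(SimpleNamespace(model_type="bert"))  # the factory is called with the model's config
print(kind, type(onnx_config).__name__)
```

Returning a factory rather than an instance lets the caller pass the loaded model's own configuration when building the ONNX config.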
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
|
||||||
|
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name or path on disk to load.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--features",
|
||||||
|
choices=["default"],
|
||||||
|
default="default",
|
||||||
|
help="Export the model with some additional features.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12)."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--atol", type=float, default=1e-4, help="Absolute difference tolerance when validating the model."
|
||||||
|
)
|
||||||
|
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
|
||||||
|
|
||||||
|
# Retrieve CLI arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")
|
||||||
|
|
||||||
|
if not args.output.parent.exists():
|
||||||
|
args.output.parent.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Allocate the model
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||||
|
model = get_model_from_features(args.features, args.model)
|
||||||
|
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
|
||||||
|
onnx_config = model_onnx_config(model.config)
|
||||||
|
|
||||||
|
# Ensure the requested opset is sufficient
|
||||||
|
if args.opset < onnx_config.default_onnx_opset:
|
||||||
|
raise ValueError(
|
||||||
|
f"Opset {args.opset} is not sufficient to export {model_kind}. "
|
||||||
|
f"At least {onnx_config.default_onnx_opset} is required."
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)
|
||||||
|
|
||||||
|
validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
|
||||||
|
logger.info(f"All good, model saved at: {args.output.as_posix()}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logger = logging.get_logger("transformers.onnx") # pylint: disable=invalid-name
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
main()
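The way `main()` resolves `args.output` has a consequence worth spelling out: `Path.is_file()` is only true for a file that already exists, so a path to a not-yet-created file is treated like a directory. A small sketch of that rule, where `resolve_output` is a hypothetical helper that repeats the expression used in `main()`:

```python
import tempfile
from pathlib import Path


def resolve_output(output: Path) -> Path:
    # Same rule as in main(): keep existing files, otherwise append "model.onnx"
    return output if output.is_file() else output.joinpath("model.onnx")


with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)
    existing = directory / "existing.onnx"
    existing.touch()

    from_directory = resolve_output(directory)  # <tmp>/model.onnx
    from_existing_file = resolve_output(existing)  # <tmp>/existing.onnx
    from_missing_file = resolve_output(directory / "new.onnx")  # <tmp>/new.onnx/model.onnx

print(from_directory.name, from_existing_file.name, from_missing_file.parent.name)
```

Under this rule, passing a directory is the predictable way to choose where `model.onnx` ends up.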
|
||||||
223
src/transformers/onnx/config.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||||
|
|
||||||
|
from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_ONNX_OPSET = 11
|
||||||
|
|
||||||
|
# 2 GiB
|
||||||
|
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxConfig(ABC):
|
||||||
|
"""
|
||||||
|
Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_FIXED_BATCH = 2
|
||||||
|
DEFAULT_FIXED_SEQUENCE = 8
|
||||||
|
|
||||||
|
def __init__(self, config: PretrainedConfig):
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls, config: PretrainedConfig) -> "OnnxConfig":
|
||||||
|
"""
|
||||||
|
Instantiate an OnnxConfig for a specific model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The model's configuration to use when exporting to ONNX
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OnnxConfig for this model
|
||||||
|
"""
|
||||||
|
return cls(config)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
"""
|
||||||
|
Mapping containing the axis definition of the input tensors to provide to the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
For each input: its name associated to the axes symbolic name and the axis position within the tensor
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
"""
|
||||||
|
Mapping containing the axis definition of the output tensors produced by the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
For each output: its name associated to the axes symbolic name and the axis position within the tensor
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||||
|
"""
|
||||||
|
Dictionary of keys to override in the model's config before exporting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with the keys (and their corresponding values) to override
|
||||||
|
"""
|
||||||
|
if hasattr(self._config, "use_cache"):
|
||||||
|
return {"use_cache": False}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_batch_size(self) -> int:
|
||||||
|
"""
|
||||||
|
The default batch size to use if no other indication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer > 0
|
||||||
|
"""
|
||||||
|
# Using 2 avoids ONNX making assumptions about single-sample batches
|
||||||
|
return OnnxConfig.DEFAULT_FIXED_BATCH
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_sequence_length(self) -> int:
|
||||||
|
"""
|
||||||
|
The default sequence length to use if no other indication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer > 0
|
||||||
|
"""
|
||||||
|
return OnnxConfig.DEFAULT_FIXED_SEQUENCE
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_onnx_opset(self) -> int:
|
||||||
|
"""
|
||||||
|
Which onnx opset to use when exporting the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer ONNX Opset version
|
||||||
|
"""
|
||||||
|
return DEFAULT_ONNX_OPSET
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def use_external_data_format(num_parameters: int) -> bool:
|
||||||
|
"""
|
||||||
|
Flag indicating if the model requires using external data format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_parameters: Number of parameters in the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if model.num_parameters() * size_of(float32) >= 2 GiB, False otherwise
|
||||||
|
"""
|
||||||
|
|
||||||
|
return (
|
||||||
|
compute_serialized_parameters_size(num_parameters, ParameterFormat.Float)
|
||||||
|
>= EXTERNAL_DATA_FORMAT_SIZE_LIMIT
|
||||||
|
)
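The threshold can be worked out by hand. The sketch below assumes that `compute_serialized_parameters_size` simply multiplies the parameter count by the byte size of the data type (4 bytes for float32), which matches the docstring but is not shown in this diff; `needs_external_data_format` is a hypothetical name.

```python
EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024  # same constant as above, in bytes
FLOAT32_SIZE_IN_BYTES = 4  # assumption: parameters are serialized as float32


def needs_external_data_format(num_parameters: int) -> bool:
    """True when the serialized float32 parameters reach the size limit."""
    return num_parameters * FLOAT32_SIZE_IN_BYTES >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT


# Smallest parameter count that triggers the external data format
threshold = EXTERNAL_DATA_FORMAT_SIZE_LIMIT // FLOAT32_SIZE_IN_BYTES
print(threshold, needs_external_data_format(110_000_000))
```

Under that assumption the switch happens at 536,870,912 parameters, so a model with roughly 110 million parameters stays well below the limit.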
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate inputs to provide to the ONNX exporter for the specific framework
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer: The tokenizer associated with this model configuration
|
||||||
|
batch_size: The batch size (int) to export the model for (-1 means dynamic axis)
|
||||||
|
seq_length: The sequence length (int) to export the model for (-1 means dynamic axis)
|
||||||
|
is_pair: Indicate if the input is a pair (sentence 1, sentence 2)
|
||||||
|
framework: The framework (optional) the tokenizer will generate tensor for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(
|
||||||
|
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
|
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||||
|
seq_length = compute_effective_axis_dimension(
|
||||||
|
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate dummy inputs according to the computed batch size and sequence length
|
||||||
|
dummy_input = [" ".join([tokenizer.unk_token] * seq_length)] * batch_size
|
||||||
|
return dict(tokenizer(dummy_input, return_tensors=framework))
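A rough sketch of how the dummy batch is sized when both axes are dynamic. `compute_effective_axis_dimension` lives in `.utils` and is not part of this diff, so the body below is an assumption about its behaviour; the unknown token `[UNK]` and the two special tokens are illustrative values, and the sketch joins `seq_length` copies of the token with spaces, which is the assumed intent of the dummy string.

```python
def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    # Assumed behaviour: fall back to the fixed size for a dynamic axis (-1),
    # then leave room for the special tokens the tokenizer will add.
    if dimension <= 0:
        dimension = fixed_dimension
    return dimension - num_token_to_add


batch_size = compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0)
seq_length = compute_effective_axis_dimension(-1, fixed_dimension=8, num_token_to_add=2)

# One string of space-separated unknown tokens per sample
dummy_input = [" ".join(["[UNK]"] * seq_length)] * batch_size
print(batch_size, seq_length, dummy_input[0])
```

With these values each sample carries 6 unknown tokens, so the tokenized length comes back to 8 once the 2 special tokens are added.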
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||||
|
def __init__(self, config: PretrainedConfig, use_past: bool = False):
|
||||||
|
super().__init__(config)
|
||||||
|
self.use_past = use_past
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def with_past(cls, config: PretrainedConfig) -> "OnnxConfigWithPast":
|
||||||
|
"""
|
||||||
|
Instantiate an OnnxConfig with the `use_past` attribute set to True
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The underlying model's config to use when exporting to ONNX
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OnnxConfig with `.use_past = True`
|
||||||
|
"""
|
||||||
|
return cls(config, use_past=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def values_override(self) -> Optional[Mapping[str, Any]]:
|
||||||
|
if hasattr(self._config, "use_cache"):
|
||||||
|
return {"use_cache": self.use_past}
|
||||||
|
|
||||||
|
return None
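The effect of `values_override` can be sketched with a plain namespace standing in for the model configuration. `apply_overrides` is a hypothetical helper that applies each key with `setattr`; the free function `values_override` mirrors the property above.

```python
from types import SimpleNamespace
from typing import Any, Mapping, Optional


def values_override(config, use_past: bool = False) -> Optional[Mapping[str, Any]]:
    """Mirror of the property above: only configs exposing `use_cache` get an override."""
    if hasattr(config, "use_cache"):
        return {"use_cache": use_past}
    return None


def apply_overrides(config, overrides: Optional[Mapping[str, Any]]):
    # Hypothetical helper: copy every override onto the config object
    if overrides is not None:
        for key, value in overrides.items():
            setattr(config, key, value)
    return config


cached = SimpleNamespace(use_cache=True)  # stand-in for a config that supports caching
apply_overrides(cached, values_override(cached, use_past=False))
no_cache_attribute = values_override(SimpleNamespace())
print(cached.use_cache, no_cache_attribute)
```

Without past keys the cache is switched off before export; with `use_past=True` it stays enabled so the cached keys and values appear among the outputs.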
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(
|
||||||
|
batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
|
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||||
|
|
||||||
|
# When use_past is set, the caching mechanism requires the input to be a single token
|
||||||
|
fixed_sequence_length = 1 if self.use_past else self.default_sequence_length
|
||||||
|
seq_length = compute_effective_axis_dimension(
|
||||||
|
seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate dummy inputs according to the computed batch size and sequence length
|
||||||
|
dummy_input = [" ".join([tokenizer.unk_token] * seq_length)] * batch_size
|
||||||
|
return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
|
||||||
225
src/transformers/onnx/convert.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from inspect import signature
|
||||||
|
from itertools import chain
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
|
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
||||||
|
from ..utils import logging
|
||||||
|
from .config import OnnxConfig
|
||||||
|
from .utils import flatten_output_collection_property
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# This is the minimal required version to support some ONNX Runtime features
|
||||||
|
ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
|
||||||
|
|
||||||
|
|
||||||
|
def check_onnxruntime_requirements(minimum_version: Version):
|
||||||
|
"""
|
||||||
|
Check that onnxruntime is installed and that the installed version is recent enough
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If onnxruntime is not installed or the installed version is too old
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import onnxruntime
|
||||||
|
|
||||||
|
# Parse the version of the installed onnxruntime
|
||||||
|
ort_version = parse(onnxruntime.__version__)
|
||||||
|
|
||||||
|
# Enforce the minimum version requested by the caller
|
||||||
|
if ort_version < minimum_version:
|
||||||
|
raise ImportError(
|
||||||
|
f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
|
||||||
|
f"but we require onnxruntime to be >= {minimum_version} to enable all the conversion options.\n"
|
||||||
|
f"Please update onnxruntime by running `pip install --upgrade onnxruntime`"
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"onnxruntime doesn't seem to be currently installed. "
|
||||||
|
"Please install the onnxruntime by running `pip install onnxruntime`"
|
||||||
|
" and relaunch the conversion."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def export(
    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        tokenizer: The tokenizer used to generate the dummy inputs for the export
        model: The PyTorch model to export
        config: The ONNX configuration associated with the exported model
        opset: The version of the ONNX operator set to use
        output: Path where the exported ONNX model is written

    Returns:
        A tuple with the model's inputs matched against the config (in `forward` order) and the named outputs
        declared by the ONNX configuration
    """
    if not is_torch_available():
        raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch

    # Imported under an alias so it does not shadow this function's own name
    from torch.onnx import export as onnx_export

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    torch.set_grad_enabled(False)
    model.config.return_dict = True

    # Check if we need to override certain configuration items
    if config.values_override is not None:
        logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
        for override_config_key, override_config_value in config.values_override.items():
            logger.info(f"\t- {override_config_key} -> {override_config_value}")
            setattr(model.config, override_config_key, override_config_value)

    # Ensure inputs match
    # TODO: Check when exporting QA we provide "is_pair=True"
    model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
    onnx_outputs = list(config.outputs.keys())

    if not inputs_match:
        raise ValueError("Model and config inputs don't match")

    # export can work with named args, but the dict containing the named args has to be the last element of
    # the args tuple
    onnx_export(
        model,
        (model_inputs,),
        f=output.as_posix(),
        input_names=list(config.inputs.keys()),
        output_names=onnx_outputs,
        dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
        do_constant_folding=True,
        use_external_data_format=config.use_external_data_format(model.num_parameters()),
        enable_onnx_checker=True,
        opset_version=opset,
    )

    return matched_inputs, onnx_outputs


def validate_model_outputs(
    config: OnnxConfig,
    tokenizer: PreTrainedTokenizer,
    reference_model: Union[PreTrainedModel, TFPreTrainedModel],
    onnx_model: Path,
    onnx_named_outputs: List[str],
    atol: float,
):
    from onnxruntime import InferenceSession, SessionOptions

    logger.info("Validating ONNX model...")

    reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)

    # Create ONNX Runtime session
    options = SessionOptions()
    session = InferenceSession(onnx_model.as_posix(), options)

    # Compute outputs from the reference model
    ref_outputs = reference_model(**reference_model_inputs)
    ref_outputs_dict = {}

    # We flatten potential collections of outputs (e.g. past_keys) to a flat structure
    for name, value in ref_outputs.items():
        if isinstance(value, (list, tuple)):
            value = flatten_output_collection_property(name, value)
            ref_outputs_dict.update(value)
        else:
            ref_outputs_dict[name] = value

    # We flatten potential collections of inputs (e.g. past_keys)
    onnx_inputs = {}
    for name, value in reference_model_inputs.items():
        if isinstance(value, (list, tuple)):
            value = flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
        else:
            onnx_inputs[name] = value.numpy()

    # Compute outputs from the ONNX model
    onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)

    # Check that the ONNX output names are a subset of the reference model's output names
    ref_outputs_set, onnx_outputs_set = set(ref_outputs_dict.keys()), set(onnx_named_outputs)
    if not onnx_outputs_set.issubset(ref_outputs_set):
        logger.info(
            f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}"
        )

        raise ValueError(
            "Outputs don't match between reference model and ONNX exported model: "
            f"{onnx_outputs_set.difference(ref_outputs_set)}"
        )
    else:
        logger.info(f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")

    # Check the shape and values match
    for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
        ref_value = ref_outputs_dict[name].numpy()
        logger.info(f'\t- Validating ONNX Model output "{name}":')

        # Shape
        if not ort_value.shape == ref_value.shape:
            logger.info(f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
            raise ValueError(
                "Output shapes don't match between reference model and ONNX exported model: "
                f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
            )
        else:
            logger.info(f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")

        # Values
        if not np.allclose(ref_value, ort_value, atol=atol):
            logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
            raise ValueError(
                "Output values don't match between reference model and ONNX exported model: "
                f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}"
            )
        else:
            logger.info(f"\t\t-[✓] all values close (atol: {atol})")


def ensure_model_and_config_inputs_match(
    model: Union[PreTrainedModel, TFPreTrainedModel], model_inputs: Iterable[str]
) -> Tuple[bool, List[str]]:
    """
    Check that the inputs generated from the ONNX config are all accepted by the model's forward method.

    :param model: The model whose `forward` signature is inspected
    :param model_inputs: Names of the inputs generated from the ONNX config
    :return: Whether every name in `model_inputs` is a parameter of `model.forward`, and the matching names
        ordered as they appear in the `forward` signature
    """
    forward_parameters = signature(model.forward).parameters
    model_inputs_set = set(model_inputs)

    # We are fine if the forward signature accepts more parameters than the config provides
    forward_inputs_set = set(forward_parameters.keys())
    is_ok = model_inputs_set.issubset(forward_inputs_set)

    # Make sure the input order matches the forward signature (VERY IMPORTANT !!!!)
    matching_inputs = forward_inputs_set.intersection(model_inputs_set)
    ordered_inputs = [parameter for parameter in forward_parameters.keys() if parameter in matching_inputs]
    return is_ok, ordered_inputs
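The input-matching step in `ensure_model_and_config_inputs_match` can be tried without loading a model. The following is a minimal sketch, assuming a hypothetical `toy_forward` function as a stand-in for `model.forward`; `toy_forward` and `match_and_order` are illustrative names that are not part of this commit.

```python
from inspect import signature
from typing import Callable, Iterable, List, Tuple


def toy_forward(input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
    """Hypothetical stand-in for `model.forward`; only its signature matters here."""
    return None


def match_and_order(forward: Callable, provided: Iterable[str]) -> Tuple[bool, List[str]]:
    """Return whether all provided names are forward parameters, plus those names in signature order."""
    parameters = signature(forward).parameters
    provided_set = set(provided)
    # Extra forward parameters are fine; unknown provided names are not
    is_ok = provided_set.issubset(parameters.keys())
    # Iterating over `parameters` preserves the declaration order of the signature
    ordered = [name for name in parameters if name in provided_set]
    return is_ok, ordered


ok, ordered = match_and_order(toy_forward, ["attention_mask", "input_ids"])
print(ok, ordered)
```

The names come back in signature order regardless of the order in which they were provided, which matters because the exporter binds inputs positionally.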
82
src/transformers/onnx/utils.py
Normal file
@@ -0,0 +1,82 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ctypes import c_float, sizeof
from enum import Enum
from typing import Any, Dict, Iterable


class ParameterFormat(Enum):
    Float = c_float

    @property
    def size(self) -> int:
        """
        Number of bytes required for this data type

        Returns:
            Integer > 0
        """
        return sizeof(self.value)


def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
    """
    Compute the dimension to use for an axis when generating dummy inputs.

    Args:
        dimension: The requested dimension for the axis (<= 0 when the axis is dynamic)
        fixed_dimension: The value to fall back to when `dimension` is <= 0
        num_token_to_add: The number of tokens to subtract from the dimension (e.g. tokens added by the tokenizer)

    Returns:
        The effective dimension for the axis
    """
    # <= 0 is possible if using a dynamic axis
    if dimension <= 0:
        dimension = fixed_dimension

    dimension -= num_token_to_add
    return dimension


def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterFormat) -> int:
    """
    Compute the size taken by all the parameters in the given storage format when serializing the model

    Args:
        num_parameters: Number of parameters to be saved
        dtype: The data format each parameter will be saved in

    Returns:
        Size (in bytes) taken to save all the parameters
    """
    return num_parameters * dtype.size


def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
    """
    Flatten any potential nested structure expanding the name of the field with the index of the element within the
    structure.

    Args:
        name: The name of the nested structure
        field: The structure to, potentially, be flattened

    Returns:
        (Dict[str, Any]): Outputs with flattened structure and key mapping this new structure.

    """
    from itertools import chain

    return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
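To make the naming scheme of `flatten_output_collection_property` concrete, here is a small standalone sketch. It repeats the function body from the file above so that it runs on its own; the `past` value is a made-up placeholder using strings where the real pipeline would hold tensors.

```python
from itertools import chain
from typing import Any, Dict, Iterable


def flatten_output_collection_property(name: str, field: Iterable[Any]) -> Dict[str, Any]:
    # Standalone copy of the helper above: flattens exactly one level of nesting
    return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}


# Hypothetical two-layer cache with one (key, value) pair per layer
past = (("layer0_key", "layer0_value"), ("layer1_key", "layer1_value"))
flat = flatten_output_collection_property("past_key_values", past)

for flat_name, item in flat.items():
    print(flat_name, "->", item)
```

Only one level is flattened and the index is a running counter over all inner elements, so a deeper nesting would leave inner collections intact as values.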
@@ -33,6 +33,7 @@ from .file_utils import (
     is_datasets_available,
     is_faiss_available,
     is_flax_available,
+    is_keras2onnx_available,
     is_onnx_available,
     is_pandas_available,
     is_rjieba_available,
@@ -234,6 +235,13 @@ def require_rjieba(test_case):
     return test_case


+def require_keras2onnx(test_case):
+    if not is_keras2onnx_available():
+        return unittest.skip("test requires keras2onnx")(test_case)
+    else:
+        return test_case
+
+
 def require_onnx(test_case):
     if not is_onnx_available():
         return unittest.skip("test requires ONNX")(test_case)
@@ -35,7 +35,7 @@ from transformers.testing_utils import (
     _tf_gpu_memory_limit,
     is_pt_tf_cross_test,
     is_staging_test,
-    require_onnx,
+    require_keras2onnx,
     require_tf,
     slow,
     tooslow,
@@ -325,7 +325,7 @@ class TFModelTesterMixin:

             self.assertEqual(len(incompatible_ops), 0, incompatible_ops)

-    @require_onnx
+    @require_keras2onnx
     @slow
     def test_onnx_runtime_optimize(self):
         if not self.test_onnx:
251
tests/test_onnx_v2.py
Normal file
@@ -0,0 +1,251 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch

from transformers import (  # LongformerConfig,
    AlbertConfig,
    AutoTokenizer,
    BartConfig,
    DistilBertConfig,
    GPT2Config,
    RobertaConfig,
    T5Config,
    XLMRobertaConfig,
    is_torch_available,
)
from transformers.models.albert import AlbertOnnxConfig
from transformers.models.bart import BartOnnxConfig
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
from transformers.models.distilbert import DistilBertOnnxConfig

# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
from transformers.models.t5 import T5OnnxConfig
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
from transformers.onnx.utils import (
    compute_effective_axis_dimension,
    compute_serialized_parameters_size,
    flatten_output_collection_property,
)
from transformers.testing_utils import require_onnx, require_torch, slow


@require_onnx
class OnnxUtilsTestCaseV2(TestCase):
    """
    Cover all the utilities involved to export ONNX models
    """

    def test_compute_effective_axis_dimension(self):
        """
        When exporting an ONNX model with a dynamic axis (batch or sequence) we set batch_size and/or
        sequence_length = -1. We cannot generate an effective tensor with axis dim == -1, so we use some "fixed"
        values instead (> 1 to avoid ONNX squeezing the axis).

        This test ensures we are correctly replacing generated batch / sequence tensor with axis > 1
        """

        # Dynamic axis (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0), 2)

        # Static axis (batch, no token added by the tokenizer)
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=2, num_token_to_add=0), 2)

        # Dynamic axis (sequence, token added by the tokenizer 2 (no pair))
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=2), 6)

        # Dynamic axis (sequence, token added by the tokenizer 3 (pair))
        self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5)

    def test_compute_parameters_serialized_size(self):
        """
        This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the
        parameters for the specified parameter's dtype.
        """
        self.assertEqual(compute_serialized_parameters_size(2, ParameterFormat.Float), 2 * ParameterFormat.Float.size)

    def test_flatten_output_collection_property(self):
        """
        This test ensures we correctly flatten nested collections such as the one we use when returning past_keys.
        past_keys = Tuple[Tuple]

        ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
        """
        self.assertEqual(
            flatten_output_collection_property("past_key", [[0], [1], [2]]),
            {
                "past_key.0": 0,
                "past_key.1": 1,
                "past_key.2": 2,
            },
        )


class OnnxConfigTestCaseV2(TestCase):
    """
    Cover the tests for the models' default configuration.

    Default means no specific feature is enabled on the model.
    """

    @patch.multiple(OnnxConfig, __abstractmethods__=set())
    def test_use_external_data_format(self):
        """
        External data format is required only if the serialized size of the parameters is bigger than 2Gb
        """
        TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT

        # No parameters
        self.assertFalse(OnnxConfig.use_external_data_format(0))

        # Some parameters
        self.assertFalse(OnnxConfig.use_external_data_format(1))

        # Almost 2Gb parameters
        self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // ParameterFormat.Float.size))

        # Exactly 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT))

        # More than 2Gb parameters
        self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))


class OnnxConfigWithPastTestCaseV2(TestCase):
    """
    Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX)
    """

    SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)}

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_use_past(self):
        """
        Ensure the use_past variable is correctly being set
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.with_past() should use_past"
                )

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_values_override(self):
        """
        Ensure the use_past variable correctly sets the `use_cache` value in the model's configuration
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):

                # without past
                onnx_config_default = OnnxConfigWithPast.default(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertFalse(
                    onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past"
                )

                # with past
                onnx_config_default = OnnxConfigWithPast.with_past(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertTrue(
                    onnx_config_default.values_override["use_cache"], "use_cache should be True if using past"
                )


if is_torch_available():
    from transformers import (
        AlbertModel,
        BartModel,
        BertModel,
        DistilBertModel,
        GPT2Model,
        RobertaModel,
        T5Model,
        XLMRobertaModel,
    )

    PYTORCH_EXPORT_DEFAULT_MODELS = {
        ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
        ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
        ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
        ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
        ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
        ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
        ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
    }

    PYTORCH_EXPORT_WITH_PAST_MODELS = {
        # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
    }


class OnnxExportTestCaseV2(TestCase):
    """
    Integration tests ensuring supported models are correctly exported
    """

    @slow
    @require_torch
    def test_pytorch_export_default(self):
        from transformers.onnx import export

        for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "default"))

                tokenizer = AutoTokenizer.from_pretrained(model)
                model = model_class(config_class())
                onnx_config = onnx_config_class.default(model.config)

                with NamedTemporaryFile("w") as output:
                    onnx_inputs, onnx_outputs = export(
                        tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
                    )

                    try:
                        validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
                    except ValueError as ve:
                        self.fail(f"{name} -> {ve}")

    @slow
    @require_torch
    def test_pytorch_export_with_past(self):
        from transformers.onnx import export

        for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
            with self.subTest(name):
                self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")

                tokenizer = AutoTokenizer.from_pretrained(model)
                model = model_class(config_class())
                onnx_config = onnx_config_class.with_past(model.config)

                self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
                self.assertTrue(
                    onnx_config.use_past, "OnnxConfigWithPast.use_past should be True if called with with_past()"
                )

                with NamedTemporaryFile("w") as output:
                    output = Path(output.name)
                    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)

                    try:
                        validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
                    except ValueError as ve:
                        self.fail(f"{name} -> {ve}")
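The boundary cases exercised by `test_use_external_data_format` can be reproduced in isolation. The sketch below rests on two assumptions that this part of the diff does not show: that the size limit is 2 GiB (taken from the test docstring's "2Gb") and that the comparison is `>=`. The names `ASSUMED_SIZE_LIMIT` and `needs_external_data` are hypothetical and not part of `transformers.onnx`.

```python
from ctypes import c_float, sizeof

# Assumption: the limit is 2 GiB, as suggested by the "2Gb" wording in the test docstring
ASSUMED_SIZE_LIMIT = 2 * 1024**3

# Bytes per parameter, computed the same way as ParameterFormat.Float.size
FLOAT_SIZE = sizeof(c_float)


def needs_external_data(num_parameters: int, limit: int = ASSUMED_SIZE_LIMIT) -> bool:
    """Hypothetical helper: True when the serialized float parameters reach the size limit."""
    # Assumption: the comparison is inclusive (>=), see the note below
    return num_parameters * FLOAT_SIZE >= limit


just_below = (ASSUMED_SIZE_LIMIT - 1) // FLOAT_SIZE
at_limit = ASSUMED_SIZE_LIMIT // FLOAT_SIZE

print(needs_external_data(0))
print(needs_external_data(just_below))
print(needs_external_data(at_limit))
```

Under these assumptions, `(limit + 1) // FLOAT_SIZE` rounds down to exactly `limit // FLOAT_SIZE` parameters, so the test's "more than 2Gb" case only passes if the comparison is inclusive.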