Add Tensorflow handling of ONNX conversion (#13831)
* Add TensorFlow support for ONNX export * Change documentation to mention conversion with Tensorflow * Refactor export into export_pytorch and export_tensorflow * Check model's type instead of framework installation to choose between TF and Pytorch Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Alberto Bégué <alberto.begue@della.ai> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -62,10 +62,6 @@ Ready-made configurations include the following architectures:
|
|||||||
- XLM-RoBERTa
|
- XLM-RoBERTa
|
||||||
- XLM-RoBERTa-XL
|
- XLM-RoBERTa-XL
|
||||||
|
|
||||||
The ONNX conversion is supported for the PyTorch versions of the models. If you
|
|
||||||
would like to be able to convert a TensorFlow model, please let us know by
|
|
||||||
opening an issue.
|
|
||||||
|
|
||||||
In the next two sections, we'll show you how to:
|
In the next two sections, we'll show you how to:
|
||||||
|
|
||||||
* Export a supported model using the `transformers.onnx` package.
|
* Export a supported model using the `transformers.onnx` package.
|
||||||
@@ -150,6 +146,8 @@ DistilBERT we have:
|
|||||||
["last_hidden_state"]
|
["last_hidden_state"]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The approach is similar for TensorFlow models.
|
||||||
|
|
||||||
### Selecting features for different model topologies
|
### Selecting features for different model topologies
|
||||||
|
|
||||||
Each ready-made configuration comes with a set of _features_ that enable you to
|
Each ready-made configuration comes with a set of _features_ that enable you to
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import numpy as np
|
|||||||
from packaging.version import Version, parse
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
||||||
from transformers.file_utils import is_torch_onnx_dict_inputs_support_available
|
from transformers.file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
|
||||||
from transformers.onnx.config import OnnxConfig
|
from transformers.onnx.config import OnnxConfig
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
@@ -62,32 +62,35 @@ def check_onnxruntime_requirements(minimum_version: Version):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def export(
|
def export_pytorch(
|
||||||
tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
model: PreTrainedModel,
|
||||||
|
config: OnnxConfig,
|
||||||
|
opset: int,
|
||||||
|
output: Path,
|
||||||
) -> Tuple[List[str], List[str]]:
|
) -> Tuple[List[str], List[str]]:
|
||||||
"""
|
"""
|
||||||
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
|
Export a PyTorch model to an ONNX Intermediate Representation (IR)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tokenizer:
|
tokenizer ([`PreTrainedTokenizer`]):
|
||||||
model:
|
The tokenizer used for encoding the data.
|
||||||
config:
|
model ([`PreTrainedModel`]):
|
||||||
opset:
|
The model to export.
|
||||||
output:
|
config ([`~onnx.config.OnnxConfig`]):
|
||||||
|
The ONNX configuration associated with the exported model.
|
||||||
|
opset (`int`):
|
||||||
|
The version of the ONNX operator set to use.
|
||||||
|
output (`Path`):
|
||||||
|
Directory to store the exported ONNX model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||||
|
the ONNX configuration.
|
||||||
"""
|
"""
|
||||||
if not is_torch_available():
|
if issubclass(type(model), PreTrainedModel):
|
||||||
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export as onnx_export
|
||||||
|
|
||||||
from ..file_utils import torch_version
|
|
||||||
|
|
||||||
if not is_torch_onnx_dict_inputs_support_available():
|
|
||||||
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
|
||||||
|
|
||||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -117,7 +120,7 @@ def export(
|
|||||||
if parse(torch.__version__) <= parse("1.10.99"):
|
if parse(torch.__version__) <= parse("1.10.99"):
|
||||||
# export can work with named args but the dict containing named args
|
# export can work with named args but the dict containing named args
|
||||||
# has to be the last element of the args tuple.
|
# has to be the last element of the args tuple.
|
||||||
export(
|
onnx_export(
|
||||||
model,
|
model,
|
||||||
(model_inputs,),
|
(model_inputs,),
|
||||||
f=output.as_posix(),
|
f=output.as_posix(),
|
||||||
@@ -130,7 +133,7 @@ def export(
|
|||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
export(
|
onnx_export(
|
||||||
model,
|
model,
|
||||||
(model_inputs,),
|
(model_inputs,),
|
||||||
f=output.as_posix(),
|
f=output.as_posix(),
|
||||||
@@ -146,6 +149,103 @@ def export(
|
|||||||
return matched_inputs, onnx_outputs
|
return matched_inputs, onnx_outputs
|
||||||
|
|
||||||
|
|
||||||
|
def export_tensorflow(
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
model: TFPreTrainedModel,
|
||||||
|
config: OnnxConfig,
|
||||||
|
opset: int,
|
||||||
|
output: Path,
|
||||||
|
) -> Tuple[List[str], List[str]]:
|
||||||
|
"""
|
||||||
|
Export a TensorFlow model to an ONNX Intermediate Representation (IR)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer ([`PreTrainedTokenizer`]):
|
||||||
|
The tokenizer used for encoding the data.
|
||||||
|
model ([`TFPreTrainedModel`]):
|
||||||
|
The model to export.
|
||||||
|
config ([`~onnx.config.OnnxConfig`]):
|
||||||
|
The ONNX configuration associated with the exported model.
|
||||||
|
opset (`int`):
|
||||||
|
The version of the ONNX operator set to use.
|
||||||
|
output (`Path`):
|
||||||
|
Directory to store the exported ONNX model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||||
|
the ONNX configuration.
|
||||||
|
"""
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
import onnx
|
||||||
|
import tf2onnx
|
||||||
|
|
||||||
|
model.config.return_dict = True
|
||||||
|
|
||||||
|
# Check if we need to override certain configuration item
|
||||||
|
if config.values_override is not None:
|
||||||
|
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
|
||||||
|
for override_config_key, override_config_value in config.values_override.items():
|
||||||
|
logger.info(f"\t- {override_config_key} -> {override_config_value}")
|
||||||
|
setattr(model.config, override_config_key, override_config_value)
|
||||||
|
|
||||||
|
# Ensure inputs match
|
||||||
|
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
|
||||||
|
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||||
|
onnx_outputs = list(config.outputs.keys())
|
||||||
|
|
||||||
|
input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
|
||||||
|
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
|
||||||
|
onnx.save(onnx_model, output.as_posix())
|
||||||
|
config.restore_ops()
|
||||||
|
|
||||||
|
return matched_inputs, onnx_outputs
|
||||||
|
|
||||||
|
|
||||||
|
def export(
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
model: Union[PreTrainedModel, TFPreTrainedModel],
|
||||||
|
config: OnnxConfig,
|
||||||
|
opset: int,
|
||||||
|
output: Path,
|
||||||
|
) -> Tuple[List[str], List[str]]:
|
||||||
|
"""
|
||||||
|
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer ([`PreTrainedTokenizer`]):
|
||||||
|
The tokenizer used for encoding the data.
|
||||||
|
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
|
||||||
|
The model to export.
|
||||||
|
config ([`~onnx.config.OnnxConfig`]):
|
||||||
|
The ONNX configuration associated with the exported model.
|
||||||
|
opset (`int`):
|
||||||
|
The version of the ONNX operator set to use.
|
||||||
|
output (`Path`):
|
||||||
|
Directory to store the exported ONNX model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||||
|
the ONNX configuration.
|
||||||
|
"""
|
||||||
|
if not (is_torch_available() or is_tf_available()):
|
||||||
|
raise ImportError(
|
||||||
|
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
|
||||||
|
"Please install torch or tensorflow first."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.file_utils import torch_version
|
||||||
|
|
||||||
|
if not is_torch_onnx_dict_inputs_support_available():
|
||||||
|
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
||||||
|
|
||||||
|
if is_torch_available() and issubclass(type(model), PreTrainedModel):
|
||||||
|
return export_pytorch(tokenizer, model, config, opset, output)
|
||||||
|
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
|
||||||
|
return export_tensorflow(tokenizer, model, config, opset, output)
|
||||||
|
|
||||||
|
|
||||||
def validate_model_outputs(
|
def validate_model_outputs(
|
||||||
config: OnnxConfig,
|
config: OnnxConfig,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
@@ -160,7 +260,10 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||||
# dynamic input shapes.
|
# dynamic input shapes.
|
||||||
|
if issubclass(type(reference_model), PreTrainedModel):
|
||||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||||
|
else:
|
||||||
|
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
|
||||||
|
|
||||||
# Create ONNX Runtime session
|
# Create ONNX Runtime session
|
||||||
options = SessionOptions()
|
options = SessionOptions()
|
||||||
@@ -210,7 +313,10 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
# Check the shape and values match
|
# Check the shape and values match
|
||||||
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
||||||
|
if issubclass(type(reference_model), PreTrainedModel):
|
||||||
ref_value = ref_outputs_dict[name].detach().numpy()
|
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||||
|
else:
|
||||||
|
ref_value = ref_outputs_dict[name].numpy()
|
||||||
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
||||||
|
|
||||||
# Shape
|
# Shape
|
||||||
@@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(
|
|||||||
|
|
||||||
:param model_inputs: :param config_inputs: :return:
|
:param model_inputs: :param config_inputs: :return:
|
||||||
"""
|
"""
|
||||||
|
if issubclass(type(model), PreTrainedModel):
|
||||||
forward_parameters = signature(model.forward).parameters
|
forward_parameters = signature(model.forward).parameters
|
||||||
|
else:
|
||||||
|
forward_parameters = signature(model.call).parameters
|
||||||
model_inputs_set = set(model_inputs)
|
model_inputs_set = set(model_inputs)
|
||||||
|
|
||||||
# We are fine if config_inputs has more keys than model_inputs
|
# We are fine if config_inputs has more keys than model_inputs
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from functools import partial, reduce
|
from functools import partial, reduce
|
||||||
from typing import Callable, Dict, Optional, Tuple, Type
|
from typing import Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
from .. import PretrainedConfig, is_torch_available
|
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
|
||||||
from ..models.albert import AlbertOnnxConfig
|
from ..models.albert import AlbertOnnxConfig
|
||||||
from ..models.bart import BartOnnxConfig
|
from ..models.bart import BartOnnxConfig
|
||||||
from ..models.bert import BertOnnxConfig
|
from ..models.bert import BertOnnxConfig
|
||||||
@@ -24,7 +24,6 @@ from .config import OnnxConfig
|
|||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers import PreTrainedModel
|
|
||||||
from transformers.models.auto import (
|
from transformers.models.auto import (
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -35,9 +34,20 @@ if is_torch_available():
|
|||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
)
|
)
|
||||||
|
elif is_tf_available():
|
||||||
|
from transformers.models.auto import (
|
||||||
|
TFAutoModel,
|
||||||
|
TFAutoModelForCausalLM,
|
||||||
|
TFAutoModelForMaskedLM,
|
||||||
|
TFAutoModelForMultipleChoice,
|
||||||
|
TFAutoModelForQuestionAnswering,
|
||||||
|
TFAutoModelForSeq2SeqLM,
|
||||||
|
TFAutoModelForSequenceClassification,
|
||||||
|
TFAutoModelForTokenClassification,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The ONNX export features are only supported for PyTorch, you will not be able to export models without it."
|
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -80,6 +90,17 @@ class FeaturesManager:
|
|||||||
"multiple-choice": AutoModelForMultipleChoice,
|
"multiple-choice": AutoModelForMultipleChoice,
|
||||||
"question-answering": AutoModelForQuestionAnswering,
|
"question-answering": AutoModelForQuestionAnswering,
|
||||||
}
|
}
|
||||||
|
elif is_tf_available():
|
||||||
|
_TASKS_TO_AUTOMODELS = {
|
||||||
|
"default": TFAutoModel,
|
||||||
|
"masked-lm": TFAutoModelForMaskedLM,
|
||||||
|
"causal-lm": TFAutoModelForCausalLM,
|
||||||
|
"seq2seq-lm": TFAutoModelForSeq2SeqLM,
|
||||||
|
"sequence-classification": TFAutoModelForSequenceClassification,
|
||||||
|
"token-classification": TFAutoModelForTokenClassification,
|
||||||
|
"multiple-choice": TFAutoModelForMultipleChoice,
|
||||||
|
"question-answering": TFAutoModelForQuestionAnswering,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
_TASKS_TO_AUTOMODELS = {}
|
_TASKS_TO_AUTOMODELS = {}
|
||||||
|
|
||||||
@@ -270,7 +291,7 @@ class FeaturesManager:
|
|||||||
)
|
)
|
||||||
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
|
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
|
||||||
|
|
||||||
def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
|
def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
|
||||||
"""
|
"""
|
||||||
Attempt to retrieve a model from a model's name and the feature to be enabled.
|
Attempt to retrieve a model from a model's name and the feature to be enabled.
|
||||||
|
|
||||||
@@ -286,7 +307,9 @@ class FeaturesManager:
|
|||||||
return model_class.from_pretrained(model)
|
return model_class.from_pretrained(model)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
|
def check_supported_model_or_raise(
|
||||||
|
model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
|
||||||
|
) -> Tuple[str, Callable]:
|
||||||
"""
|
"""
|
||||||
Check whether or not the model has the requested features.
|
Check whether or not the model has the requested features.
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from unittest import TestCase
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
|
||||||
from transformers.onnx import (
|
from transformers.onnx import (
|
||||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||||
OnnxConfig,
|
OnnxConfig,
|
||||||
@@ -15,11 +15,11 @@ from transformers.onnx import (
|
|||||||
from transformers.onnx.config import OnnxConfigWithPast
|
from transformers.onnx.config import OnnxConfigWithPast
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available() or is_tf_available():
|
||||||
from transformers.onnx.features import FeaturesManager
|
from transformers.onnx.features import FeaturesManager
|
||||||
|
|
||||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
@require_onnx
|
@require_onnx
|
||||||
@@ -192,19 +192,44 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
|||||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||||
|
("albert", "hf-internal-testing/tiny-albert"),
|
||||||
|
("bert", "bert-base-cased"),
|
||||||
|
("ibert", "kssteven/ibert-roberta-base"),
|
||||||
|
("camembert", "camembert-base"),
|
||||||
|
("distilbert", "distilbert-base-cased"),
|
||||||
|
("roberta", "roberta-base"),
|
||||||
|
("xlm-roberta", "xlm-roberta-base"),
|
||||||
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
|
}
|
||||||
|
|
||||||
|
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
|
||||||
|
("gpt2", "gpt2"),
|
||||||
|
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||||
|
}
|
||||||
|
|
||||||
|
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||||
|
("bart", "facebook/bart-base"),
|
||||||
|
("mbart", "sshleifer/tiny-mbart"),
|
||||||
|
("t5", "t5-small"),
|
||||||
|
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _get_models_to_test(export_models_list):
|
def _get_models_to_test(export_models_list):
|
||||||
models_to_test = []
|
models_to_test = []
|
||||||
if not is_torch_available():
|
if is_torch_available() or is_tf_available():
|
||||||
# Returning some dummy test that should not be ever called because of the @require_torch decorator.
|
|
||||||
# The reason for not returning an empty list is because parameterized.expand complains when it's empty.
|
|
||||||
return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]
|
|
||||||
for (name, model) in export_models_list:
|
for (name, model) in export_models_list:
|
||||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||||
name
|
name
|
||||||
).items():
|
).items():
|
||||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||||
return sorted(models_to_test)
|
return sorted(models_to_test)
|
||||||
|
else:
|
||||||
|
# Returning some dummy test that should not be ever called because of the @require_torch / @require_tf
|
||||||
|
# decorators.
|
||||||
|
# The reason for not returning an empty list is because parameterized.expand complains when it's empty.
|
||||||
|
return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]
|
||||||
|
|
||||||
|
|
||||||
class OnnxExportTestCaseV2(TestCase):
|
class OnnxExportTestCaseV2(TestCase):
|
||||||
@@ -212,7 +237,7 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
Integration tests ensuring supported models are correctly exported
|
Integration tests ensuring supported models are correctly exported
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
from transformers.onnx import export
|
from transformers.onnx import export
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
@@ -246,13 +271,13 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||||
@slow
|
@slow
|
||||||
@@ -260,4 +285,24 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
def test_pytorch_export_seq2seq_with_past(
|
def test_pytorch_export_seq2seq_with_past(
|
||||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||||
):
|
):
|
||||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
|
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
|
||||||
|
@slow
|
||||||
|
@require_tf
|
||||||
|
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
|
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
|
||||||
|
@slow
|
||||||
|
@require_tf
|
||||||
|
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
|
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||||
|
@slow
|
||||||
|
@require_tf
|
||||||
|
def test_tensorflow_export_seq2seq_with_past(
|
||||||
|
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||||
|
):
|
||||||
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ def check_onnx_model_list(overwrite=False):
|
|||||||
current_list, start_index, end_index, lines = _find_text_in_file(
|
current_list, start_index, end_index, lines = _find_text_in_file(
|
||||||
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
|
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
|
||||||
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
|
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
|
||||||
end_prompt="The ONNX conversion is supported for the PyTorch versions of the models.",
|
end_prompt="In the next two sections, we'll show you how to:",
|
||||||
)
|
)
|
||||||
new_list = get_onnx_model_list()
|
new_list = get_onnx_model_list()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user