Add Tensorflow handling of ONNX conversion (#13831)
* Add TensorFlow support for ONNX export * Change documentation to mention conversion with Tensorflow * Refactor export into export_pytorch and export_tensorflow * Check model's type instead of framework installation to choose between TF and Pytorch Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Alberto Bégué <alberto.begue@della.ai> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -62,10 +62,6 @@ Ready-made configurations include the following architectures:
|
||||
- XLM-RoBERTa
|
||||
- XLM-RoBERTa-XL
|
||||
|
||||
The ONNX conversion is supported for the PyTorch versions of the models. If you
|
||||
would like to be able to convert a TensorFlow model, please let us know by
|
||||
opening an issue.
|
||||
|
||||
In the next two sections, we'll show you how to:
|
||||
|
||||
* Export a supported model using the `transformers.onnx` package.
|
||||
@@ -150,6 +146,8 @@ DistilBERT we have:
|
||||
["last_hidden_state"]
|
||||
```
|
||||
|
||||
The approach is similar for TensorFlow models.
|
||||
|
||||
### Selecting features for different model topologies
|
||||
|
||||
Each ready-made configuration comes with a set of _features_ that enable you to
|
||||
|
||||
@@ -21,7 +21,7 @@ import numpy as np
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
||||
from transformers.file_utils import is_torch_onnx_dict_inputs_support_available
|
||||
from transformers.file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
|
||||
from transformers.onnx.config import OnnxConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
@@ -62,90 +62,190 @@ def check_onnxruntime_requirements(minimum_version: Version):
|
||||
)
|
||||
|
||||
|
||||
def export(
|
||||
tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
|
||||
def export_pytorch(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: PreTrainedModel,
|
||||
config: OnnxConfig,
|
||||
opset: int,
|
||||
output: Path,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
|
||||
Export a PyTorch model to an ONNX Intermediate Representation (IR)
|
||||
|
||||
Args:
|
||||
tokenizer:
|
||||
model:
|
||||
config:
|
||||
opset:
|
||||
output:
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer used for encoding the data.
|
||||
model ([`PreTrainedModel`]):
|
||||
The model to export.
|
||||
config ([`~onnx.config.OnnxConfig`]):
|
||||
The ONNX configuration associated with the exported model.
|
||||
opset (`int`):
|
||||
The version of the ONNX operator set to use.
|
||||
output (`Path`):
|
||||
Directory to store the exported ONNX model.
|
||||
|
||||
Returns:
|
||||
|
||||
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||
the ONNX configuration.
|
||||
"""
|
||||
if not is_torch_available():
|
||||
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||
if issubclass(type(model), PreTrainedModel):
|
||||
import torch
|
||||
from torch.onnx import export as onnx_export
|
||||
|
||||
import torch
|
||||
from torch.onnx import export
|
||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||
with torch.no_grad():
|
||||
model.config.return_dict = True
|
||||
model.eval()
|
||||
|
||||
from ..file_utils import torch_version
|
||||
# Check if we need to override certain configuration item
|
||||
if config.values_override is not None:
|
||||
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
|
||||
for override_config_key, override_config_value in config.values_override.items():
|
||||
logger.info(f"\t- {override_config_key} -> {override_config_value}")
|
||||
setattr(model.config, override_config_key, override_config_value)
|
||||
|
||||
if not is_torch_onnx_dict_inputs_support_available():
|
||||
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
||||
# Ensure inputs match
|
||||
# TODO: Check when exporting QA we provide "is_pair=True"
|
||||
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||
onnx_outputs = list(config.outputs.keys())
|
||||
|
||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||
with torch.no_grad():
|
||||
model.config.return_dict = True
|
||||
model.eval()
|
||||
if not inputs_match:
|
||||
raise ValueError("Model and config inputs doesn't match")
|
||||
|
||||
# Check if we need to override certain configuration item
|
||||
if config.values_override is not None:
|
||||
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
|
||||
for override_config_key, override_config_value in config.values_override.items():
|
||||
logger.info(f"\t- {override_config_key} -> {override_config_value}")
|
||||
setattr(model.config, override_config_key, override_config_value)
|
||||
config.patch_ops()
|
||||
|
||||
# Ensure inputs match
|
||||
# TODO: Check when exporting QA we provide "is_pair=True"
|
||||
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||
onnx_outputs = list(config.outputs.keys())
|
||||
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
||||
# so we check the torch version for backwards compatibility
|
||||
if parse(torch.__version__) <= parse("1.10.99"):
|
||||
# export can work with named args but the dict containing named args
|
||||
# has to be the last element of the args tuple.
|
||||
onnx_export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=config.use_external_data_format(model.num_parameters()),
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
else:
|
||||
onnx_export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
if not inputs_match:
|
||||
raise ValueError("Model and config inputs doesn't match")
|
||||
|
||||
config.patch_ops()
|
||||
|
||||
# PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
|
||||
# so we check the torch version for backwards compatibility
|
||||
if parse(torch.__version__) <= parse("1.10.99"):
|
||||
# export can work with named args but the dict containing named args
|
||||
# has to be the last element of the args tuple.
|
||||
export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=config.use_external_data_format(model.num_parameters()),
|
||||
enable_onnx_checker=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
else:
|
||||
export(
|
||||
model,
|
||||
(model_inputs,),
|
||||
f=output.as_posix(),
|
||||
input_names=list(config.inputs.keys()),
|
||||
output_names=onnx_outputs,
|
||||
dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
config.restore_ops()
|
||||
config.restore_ops()
|
||||
|
||||
return matched_inputs, onnx_outputs
|
||||
|
||||
|
||||
def export_tensorflow(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: TFPreTrainedModel,
|
||||
config: OnnxConfig,
|
||||
opset: int,
|
||||
output: Path,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Export a TensorFlow model to an ONNX Intermediate Representation (IR)
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer used for encoding the data.
|
||||
model ([`TFPreTrainedModel`]):
|
||||
The model to export.
|
||||
config ([`~onnx.config.OnnxConfig`]):
|
||||
The ONNX configuration associated with the exported model.
|
||||
opset (`int`):
|
||||
The version of the ONNX operator set to use.
|
||||
output (`Path`):
|
||||
Directory to store the exported ONNX model.
|
||||
|
||||
Returns:
|
||||
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||
the ONNX configuration.
|
||||
"""
|
||||
import tensorflow as tf
|
||||
|
||||
import onnx
|
||||
import tf2onnx
|
||||
|
||||
model.config.return_dict = True
|
||||
|
||||
# Check if we need to override certain configuration item
|
||||
if config.values_override is not None:
|
||||
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
|
||||
for override_config_key, override_config_value in config.values_override.items():
|
||||
logger.info(f"\t- {override_config_key} -> {override_config_value}")
|
||||
setattr(model.config, override_config_key, override_config_value)
|
||||
|
||||
# Ensure inputs match
|
||||
model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
|
||||
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||
onnx_outputs = list(config.outputs.keys())
|
||||
|
||||
input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
|
||||
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
|
||||
onnx.save(onnx_model, output.as_posix())
|
||||
config.restore_ops()
|
||||
|
||||
return matched_inputs, onnx_outputs
|
||||
|
||||
|
||||
def export(
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
model: Union[PreTrainedModel, TFPreTrainedModel],
|
||||
config: OnnxConfig,
|
||||
opset: int,
|
||||
output: Path,
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR)
|
||||
|
||||
Args:
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
The tokenizer used for encoding the data.
|
||||
model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
|
||||
The model to export.
|
||||
config ([`~onnx.config.OnnxConfig`]):
|
||||
The ONNX configuration associated with the exported model.
|
||||
opset (`int`):
|
||||
The version of the ONNX operator set to use.
|
||||
output (`Path`):
|
||||
Directory to store the exported ONNX model.
|
||||
|
||||
Returns:
|
||||
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||
the ONNX configuration.
|
||||
"""
|
||||
if not (is_torch_available() or is_tf_available()):
|
||||
raise ImportError(
|
||||
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
|
||||
"Please install torch or tensorflow first."
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.file_utils import torch_version
|
||||
|
||||
if not is_torch_onnx_dict_inputs_support_available():
|
||||
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
||||
|
||||
if is_torch_available() and issubclass(type(model), PreTrainedModel):
|
||||
return export_pytorch(tokenizer, model, config, opset, output)
|
||||
elif is_tf_available() and issubclass(type(model), TFPreTrainedModel):
|
||||
return export_tensorflow(tokenizer, model, config, opset, output)
|
||||
|
||||
|
||||
def validate_model_outputs(
|
||||
config: OnnxConfig,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
@@ -160,7 +260,10 @@ def validate_model_outputs(
|
||||
|
||||
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||
# dynamic input shapes.
|
||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||
if issubclass(type(reference_model), PreTrainedModel):
|
||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
|
||||
else:
|
||||
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
|
||||
|
||||
# Create ONNX Runtime session
|
||||
options = SessionOptions()
|
||||
@@ -210,7 +313,10 @@ def validate_model_outputs(
|
||||
|
||||
# Check the shape and values match
|
||||
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
||||
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||
if issubclass(type(reference_model), PreTrainedModel):
|
||||
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||
else:
|
||||
ref_value = ref_outputs_dict[name].numpy()
|
||||
logger.info(f'\t- Validating ONNX Model output "{name}":')
|
||||
|
||||
# Shape
|
||||
@@ -241,7 +347,10 @@ def ensure_model_and_config_inputs_match(
|
||||
|
||||
:param model_inputs: :param config_inputs: :return:
|
||||
"""
|
||||
forward_parameters = signature(model.forward).parameters
|
||||
if issubclass(type(model), PreTrainedModel):
|
||||
forward_parameters = signature(model.forward).parameters
|
||||
else:
|
||||
forward_parameters = signature(model.call).parameters
|
||||
model_inputs_set = set(model_inputs)
|
||||
|
||||
# We are fine if config_inputs has more keys than model_inputs
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from functools import partial, reduce
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from .. import PretrainedConfig, is_torch_available
|
||||
from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
|
||||
from ..models.albert import AlbertOnnxConfig
|
||||
from ..models.bart import BartOnnxConfig
|
||||
from ..models.bert import BertOnnxConfig
|
||||
@@ -24,7 +24,6 @@ from .config import OnnxConfig
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.models.auto import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
@@ -35,9 +34,20 @@ if is_torch_available():
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification,
|
||||
)
|
||||
elif is_tf_available():
|
||||
from transformers.models.auto import (
|
||||
TFAutoModel,
|
||||
TFAutoModelForCausalLM,
|
||||
TFAutoModelForMaskedLM,
|
||||
TFAutoModelForMultipleChoice,
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForTokenClassification,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"The ONNX export features are only supported for PyTorch, you will not be able to export models without it."
|
||||
"The ONNX export features are only supported for PyTorch or TensorFlow. You will not be able to export models without one of these libraries installed."
|
||||
)
|
||||
|
||||
|
||||
@@ -80,6 +90,17 @@ class FeaturesManager:
|
||||
"multiple-choice": AutoModelForMultipleChoice,
|
||||
"question-answering": AutoModelForQuestionAnswering,
|
||||
}
|
||||
elif is_tf_available():
|
||||
_TASKS_TO_AUTOMODELS = {
|
||||
"default": TFAutoModel,
|
||||
"masked-lm": TFAutoModelForMaskedLM,
|
||||
"causal-lm": TFAutoModelForCausalLM,
|
||||
"seq2seq-lm": TFAutoModelForSeq2SeqLM,
|
||||
"sequence-classification": TFAutoModelForSequenceClassification,
|
||||
"token-classification": TFAutoModelForTokenClassification,
|
||||
"multiple-choice": TFAutoModelForMultipleChoice,
|
||||
"question-answering": TFAutoModelForQuestionAnswering,
|
||||
}
|
||||
else:
|
||||
_TASKS_TO_AUTOMODELS = {}
|
||||
|
||||
@@ -270,7 +291,7 @@ class FeaturesManager:
|
||||
)
|
||||
return FeaturesManager._TASKS_TO_AUTOMODELS[task]
|
||||
|
||||
def get_model_from_feature(feature: str, model: str) -> PreTrainedModel:
|
||||
def get_model_from_feature(feature: str, model: str) -> Union[PreTrainedModel, TFPreTrainedModel]:
|
||||
"""
|
||||
Attempt to retrieve a model from a model's name and the feature to be enabled.
|
||||
|
||||
@@ -286,7 +307,9 @@ class FeaturesManager:
|
||||
return model_class.from_pretrained(model)
|
||||
|
||||
@staticmethod
|
||||
def check_supported_model_or_raise(model: PreTrainedModel, feature: str = "default") -> Tuple[str, Callable]:
|
||||
def check_supported_model_or_raise(
|
||||
model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
|
||||
) -> Tuple[str, Callable]:
|
||||
"""
|
||||
Check whether or not the model has the requested features.
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
||||
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
@@ -15,11 +15,11 @@ from transformers.onnx import (
|
||||
from transformers.onnx.config import OnnxConfigWithPast
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
if is_torch_available() or is_tf_available():
|
||||
from transformers.onnx.features import FeaturesManager
|
||||
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
||||
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
|
||||
|
||||
|
||||
@require_onnx
|
||||
@@ -192,19 +192,44 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||
("albert", "hf-internal-testing/tiny-albert"),
|
||||
("bert", "bert-base-cased"),
|
||||
("ibert", "kssteven/ibert-roberta-base"),
|
||||
("camembert", "camembert-base"),
|
||||
("distilbert", "distilbert-base-cased"),
|
||||
("roberta", "roberta-base"),
|
||||
("xlm-roberta", "xlm-roberta-base"),
|
||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
|
||||
("gpt2", "gpt2"),
|
||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("bart", "facebook/bart-base"),
|
||||
("mbart", "sshleifer/tiny-mbart"),
|
||||
("t5", "t5-small"),
|
||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||
}
|
||||
|
||||
|
||||
def _get_models_to_test(export_models_list):
|
||||
models_to_test = []
|
||||
if not is_torch_available():
|
||||
# Returning some dummy test that should not be ever called because of the @require_torch decorator.
|
||||
if is_torch_available() or is_tf_available():
|
||||
for (name, model) in export_models_list:
|
||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||
name
|
||||
).items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
return sorted(models_to_test)
|
||||
else:
|
||||
# Returning some dummy test that should not be ever called because of the @require_torch / @require_tf
|
||||
# decorators.
|
||||
# The reason for not returning an empty list is because parameterized.expand complains when it's empty.
|
||||
return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]
|
||||
for (name, model) in export_models_list:
|
||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||
name
|
||||
).items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
return sorted(models_to_test)
|
||||
|
||||
|
||||
class OnnxExportTestCaseV2(TestCase):
|
||||
@@ -212,7 +237,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
Integration tests ensuring supported models are correctly exported
|
||||
"""
|
||||
|
||||
def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
from transformers.onnx import export
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@@ -246,13 +271,13 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@@ -260,4 +285,24 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_pytorch_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@@ -211,7 +211,7 @@ def check_onnx_model_list(overwrite=False):
|
||||
current_list, start_index, end_index, lines = _find_text_in_file(
|
||||
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
|
||||
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
|
||||
end_prompt="The ONNX conversion is supported for the PyTorch versions of the models.",
|
||||
end_prompt="In the next two sections, we'll show you how to:",
|
||||
)
|
||||
new_list = get_onnx_model_list()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user