Fix duplicate arguments passed to dummy inputs in ONNX export (#16045)
* Fix duplicate arguments passed to dummy inputs in ONNX export * Fix M2M100 ONNX config * Ensure we check PreTrained model only if torch is available * Remove TensorFlow tests for models without PyTorch parity
This commit is contained in:
@@ -198,13 +198,13 @@ class M2M100OnnxConfig(OnnxSeq2SeqConfigWithPast):
|
|||||||
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
||||||
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
batch_size = compute_effective_axis_dimension(
|
batch_size = compute_effective_axis_dimension(
|
||||||
batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
|
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
||||||
seq_length = compute_effective_axis_dimension(
|
seq_length = compute_effective_axis_dimension(
|
||||||
seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
|
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate dummy inputs according to compute batch and sequence
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import numpy as np
|
|||||||
from packaging.version import Version, parse
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
from ..file_utils import TensorType, is_tf_available, is_torch_available, is_torch_onnx_dict_inputs_support_available
|
from ..file_utils import TensorType, is_tf_available, is_torch_available, is_torch_onnx_dict_inputs_support_available
|
||||||
|
from ..tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .config import OnnxConfig
|
from .config import OnnxConfig
|
||||||
|
|
||||||
@@ -100,11 +101,17 @@ def export_pytorch(
|
|||||||
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
`Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from
|
||||||
the ONNX configuration.
|
the ONNX configuration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
|
||||||
|
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
|
||||||
|
preprocessor = tokenizer
|
||||||
|
|
||||||
if issubclass(type(model), PreTrainedModel):
|
if issubclass(type(model), PreTrainedModel):
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export as onnx_export
|
from torch.onnx import export as onnx_export
|
||||||
@@ -123,9 +130,7 @@ def export_pytorch(
|
|||||||
|
|
||||||
# Ensure inputs match
|
# Ensure inputs match
|
||||||
# TODO: Check when exporting QA we provide "is_pair=True"
|
# TODO: Check when exporting QA we provide "is_pair=True"
|
||||||
model_inputs = config.generate_dummy_inputs(
|
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
|
||||||
preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH
|
|
||||||
)
|
|
||||||
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||||
onnx_outputs = list(config.outputs.keys())
|
onnx_outputs = list(config.outputs.keys())
|
||||||
|
|
||||||
@@ -213,11 +218,15 @@ def export_tensorflow(
|
|||||||
import onnx
|
import onnx
|
||||||
import tf2onnx
|
import tf2onnx
|
||||||
|
|
||||||
|
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
|
||||||
|
raise ValueError("You cannot provide both a tokenizer and preprocessor to export the model.")
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
|
||||||
|
preprocessor = tokenizer
|
||||||
|
|
||||||
model.config.return_dict = True
|
model.config.return_dict = True
|
||||||
|
|
||||||
@@ -229,7 +238,7 @@ def export_tensorflow(
|
|||||||
setattr(model.config, override_config_key, override_config_value)
|
setattr(model.config, override_config_key, override_config_value)
|
||||||
|
|
||||||
# Ensure inputs match
|
# Ensure inputs match
|
||||||
model_inputs = config.generate_dummy_inputs(preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW)
|
model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW)
|
||||||
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||||
onnx_outputs = list(config.outputs.keys())
|
onnx_outputs = list(config.outputs.keys())
|
||||||
|
|
||||||
@@ -273,11 +282,16 @@ def export(
|
|||||||
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
|
"Cannot convert because neither PyTorch nor TensorFlow are not installed. "
|
||||||
"Please install torch or tensorflow first."
|
"Please install torch or tensorflow first."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
|
||||||
|
raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.")
|
||||||
if tokenizer is not None:
|
if tokenizer is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
)
|
)
|
||||||
|
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
|
||||||
|
preprocessor = tokenizer
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from ..file_utils import torch_version
|
from ..file_utils import torch_version
|
||||||
@@ -309,16 +323,22 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
logger.info("Validating ONNX model...")
|
logger.info("Validating ONNX model...")
|
||||||
|
|
||||||
|
if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
|
||||||
|
raise ValueError("You cannot provide both a tokenizer and a preprocessor to validatethe model outputs.")
|
||||||
|
if tokenizer is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use `preprocessor` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
|
||||||
|
preprocessor = tokenizer
|
||||||
|
|
||||||
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||||
# dynamic input shapes.
|
# dynamic input shapes.
|
||||||
if issubclass(type(reference_model), PreTrainedModel):
|
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
|
||||||
reference_model_inputs = config.generate_dummy_inputs(
|
reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
|
||||||
preprocessor, tokenizer=tokenizer, framework=TensorType.PYTORCH
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
reference_model_inputs = config.generate_dummy_inputs(
|
reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW)
|
||||||
preprocessor, tokenizer=tokenizer, framework=TensorType.TENSORFLOW
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create ONNX Runtime session
|
# Create ONNX Runtime session
|
||||||
options = SessionOptions()
|
options = SessionOptions()
|
||||||
@@ -368,7 +388,7 @@ def validate_model_outputs(
|
|||||||
|
|
||||||
# Check the shape and values match
|
# Check the shape and values match
|
||||||
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
|
||||||
if issubclass(type(reference_model), PreTrainedModel):
|
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
|
||||||
ref_value = ref_outputs_dict[name].detach().numpy()
|
ref_value = ref_outputs_dict[name].detach().numpy()
|
||||||
else:
|
else:
|
||||||
ref_value = ref_outputs_dict[name].numpy()
|
ref_value = ref_outputs_dict[name].numpy()
|
||||||
@@ -402,7 +422,7 @@ def ensure_model_and_config_inputs_match(
|
|||||||
|
|
||||||
:param model_inputs: :param config_inputs: :return:
|
:param model_inputs: :param config_inputs: :return:
|
||||||
"""
|
"""
|
||||||
if issubclass(type(model), PreTrainedModel):
|
if is_torch_available() and issubclass(type(model), PreTrainedModel):
|
||||||
forward_parameters = signature(model.forward).parameters
|
forward_parameters = signature(model.forward).parameters
|
||||||
else:
|
else:
|
||||||
forward_parameters = signature(model.call).parameters
|
forward_parameters = signature(model.call).parameters
|
||||||
|
|||||||
@@ -196,28 +196,19 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
|||||||
("m2m-100", "facebook/m2m100_418M"),
|
("m2m-100", "facebook/m2m100_418M"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||||
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||||
("albert", "hf-internal-testing/tiny-albert"),
|
("albert", "hf-internal-testing/tiny-albert"),
|
||||||
("bert", "bert-base-cased"),
|
("bert", "bert-base-cased"),
|
||||||
("ibert", "kssteven/ibert-roberta-base"),
|
|
||||||
("camembert", "camembert-base"),
|
|
||||||
("distilbert", "distilbert-base-cased"),
|
("distilbert", "distilbert-base-cased"),
|
||||||
("roberta", "roberta-base"),
|
("roberta", "roberta-base"),
|
||||||
("xlm-roberta", "xlm-roberta-base"),
|
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
|
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||||
("gpt2", "gpt2"),
|
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {}
|
||||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
|
||||||
}
|
|
||||||
|
|
||||||
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS` once TensorFlow has parity with the PyTorch model implementations.
|
||||||
("bart", "facebook/bart-base"),
|
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
|
||||||
("mbart", "sshleifer/tiny-mbart"),
|
|
||||||
("t5", "t5-small"),
|
|
||||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_models_to_test(export_models_list):
|
def _get_models_to_test(export_models_list):
|
||||||
@@ -312,13 +303,13 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
|
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
|
||||||
@slow
|
@slow
|
||||||
@require_tf
|
@require_tf
|
||||||
def test_tensorflow_export_seq2seq_with_past(
|
def test_tensorflow_export_seq2seq_with_past(
|
||||||
|
|||||||
Reference in New Issue
Block a user