Add ONNX export for ViT (#15658)
* Add ONNX support for ViT * Refactor to use generic preprocessor * Add vision dep to tests * Extend ONNX slow tests to ViT * Add dummy image generator * Use model_type to determine modality * Add deprecation warnings for tokenizer argument * Add warning when overwriting the preprocessor * Add optional args to docstrings * Add minimum PyTorch version to OnnxConfig * Refactor OnnxConfig class variables from CONSTANT_NAME to snake_case * Add reasonable value for default atol Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -3,23 +3,25 @@ from tempfile import NamedTemporaryFile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
|
||||
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
OnnxConfigWithPast,
|
||||
ParameterFormat,
|
||||
export,
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.config import OnnxConfigWithPast
|
||||
|
||||
|
||||
if is_torch_available() or is_tf_available():
|
||||
from transformers.onnx.features import FeaturesManager
|
||||
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
|
||||
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
|
||||
|
||||
|
||||
@require_onnx
|
||||
@@ -178,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
|
||||
("roberta", "roberta-base"),
|
||||
("xlm-roberta", "xlm-roberta-base"),
|
||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||
("vit", "google/vit-base-patch16-224"),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
@@ -241,25 +244,38 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
from transformers.onnx import export
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
|
||||
# Useful for causal lm models that do not use pad tokens.
|
||||
if not getattr(config, "pad_token_id", None):
|
||||
config.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
model = model_class.from_config(config)
|
||||
onnx_config = onnx_config_class_constructor(model.config)
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.file_utils import torch_version
|
||||
|
||||
if torch_version < onnx_config.torch_onnx_minimum_version:
|
||||
pytest.skip(
|
||||
f"Skipping due to incompatible PyTorch version. Minimum required is {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
||||
)
|
||||
|
||||
# Check the modality of the inputs and instantiate the appropriate preprocessor
|
||||
if model.main_input_name == "input_ids":
|
||||
preprocessor = AutoTokenizer.from_pretrained(model_name)
|
||||
# Useful for causal lm models that do not use pad tokens.
|
||||
if not getattr(config, "pad_token_id", None):
|
||||
config.pad_token_id = preprocessor.eos_token_id
|
||||
elif model.main_input_name == "pixel_values":
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model input name: {model.main_input_name}")
|
||||
|
||||
with NamedTemporaryFile("w") as output:
|
||||
try:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
|
||||
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
|
||||
)
|
||||
validate_model_outputs(
|
||||
onnx_config,
|
||||
tokenizer,
|
||||
preprocessor,
|
||||
model,
|
||||
Path(output.name),
|
||||
onnx_outputs,
|
||||
@@ -271,6 +287,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@@ -291,6 +308,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
@require_vision
|
||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user