Add ONNX export for ViT (#15658)

* Add ONNX support for ViT

* Refactor to use generic preprocessor

* Add vision dep to tests

* Extend ONNX slow tests to ViT

* Add dummy image generator

* Use model_type to determine modality

* Add deprecation warnings for tokenizer argument

* Add warning when overwriting the preprocessor

* Add optional args to docstrings

* Add minimum PyTorch version to OnnxConfig

* Refactor OnnxConfig class variables from CONSTANT_NAME to snake_case

* Add reasonable value for default atol

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
lewtun
2022-03-09 17:36:59 +01:00
committed by GitHub
parent b7fa1e3dee
commit 50dd314d93
12 changed files with 270 additions and 91 deletions

View File

@@ -3,23 +3,25 @@ from tempfile import NamedTemporaryFile
from unittest import TestCase
from unittest.mock import patch
import pytest
from parameterized import parameterized
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
from transformers.onnx import (
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
OnnxConfig,
OnnxConfigWithPast,
ParameterFormat,
export,
validate_model_outputs,
)
from transformers.onnx.config import OnnxConfigWithPast
if is_torch_available() or is_tf_available():
from transformers.onnx.features import FeaturesManager
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
from transformers.testing_utils import require_onnx, require_tf, require_torch, require_vision, slow
@require_onnx
@@ -178,6 +180,7 @@ PYTORCH_EXPORT_MODELS = {
("roberta", "roberta-base"),
("xlm-roberta", "xlm-roberta-base"),
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
@@ -241,25 +244,38 @@ class OnnxExportTestCaseV2(TestCase):
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
from transformers.onnx import export
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
# Useful for causal lm models that do not use pad tokens.
if not getattr(config, "pad_token_id", None):
config.pad_token_id = tokenizer.eos_token_id
model_class = FeaturesManager.get_model_class_for_feature(feature)
config = AutoConfig.from_pretrained(model_name)
model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available():
from transformers.file_utils import torch_version
if torch_version < onnx_config.torch_onnx_minimum_version:
pytest.skip(
f"Skipping due to incompatible PyTorch version. Minimum required is {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
)
# Check the modality of the inputs and instantiate the appropriate preprocessor
if model.main_input_name == "input_ids":
preprocessor = AutoTokenizer.from_pretrained(model_name)
# Useful for causal lm models that do not use pad tokens.
if not getattr(config, "pad_token_id", None):
config.pad_token_id = preprocessor.eos_token_id
elif model.main_input_name == "pixel_values":
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
else:
raise ValueError(f"Unsupported model input name: {model.main_input_name}")
with NamedTemporaryFile("w") as output:
try:
onnx_inputs, onnx_outputs = export(
tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
)
validate_model_outputs(
onnx_config,
tokenizer,
preprocessor,
model,
Path(output.name),
onnx_outputs,
@@ -271,6 +287,7 @@ class OnnxExportTestCaseV2(TestCase):
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@@ -291,6 +308,7 @@ class OnnxExportTestCaseV2(TestCase):
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
@slow
@require_tf
@require_vision
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)