Add onnx support for VisionEncoderDecoder (#19254)

* Add onnx support for VisionEncoderDecoder

* Add onnx support for VisionEncoderDecoder

* Removed unused import

* Rename encoder hidden state

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docstrings and removed redundant code

* Added test function for enc-dec models

* Update doc string text

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* fixed code style

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Mohit Sharma
2022-10-10 18:50:19 +05:30
committed by GitHub
parent 298f6a98c2
commit 3080bb4754
7 changed files with 305 additions and 35 deletions

View File

@@ -161,7 +161,6 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
"""
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
with self.subTest(name):
# without past
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
@@ -220,6 +219,10 @@ PYTORCH_EXPORT_MODELS = {
("swin", "microsoft/swin-tiny-patch4-window7-224"),
}
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
("vision-encoder-decoder", "nlpconnect/vit-gpt2-image-captioning"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
("bloom", "bigscience/bloom-560m"),
("gpt2", "gpt2"),
@@ -347,6 +350,70 @@ class OnnxExportTestCaseV2(TestCase):
except (RuntimeError, ValueError) as e:
self.fail(f"{name}, {feature} -> {e}")
def _onnx_export_encoder_decoder_models(
self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"
):
from transformers import AutoFeatureExtractor, AutoTokenizer
from transformers.onnx import export
model_class = FeaturesManager.get_model_class_for_feature(feature)
config = AutoConfig.from_pretrained(model_name)
model = model_class.from_config(config)
onnx_config = onnx_config_class_constructor(model.config)
if is_torch_available():
from transformers.utils import torch_version
if torch_version < onnx_config.torch_onnx_minimum_version:
pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
)
encoder_model = model.get_encoder()
decoder_model = model.get_decoder()
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
decoder_onnx_config = onnx_config.get_decoder_config(encoder_model.config, decoder_model.config, feature)
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
onnx_opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
with NamedTemporaryFile("w") as encoder_output:
onnx_inputs, onnx_outputs = export(
preprocessor, encoder_model, encoder_onnx_config, onnx_opset, Path(encoder_output.name), device=device
)
validate_model_outputs(
encoder_onnx_config,
preprocessor,
encoder_model,
Path(encoder_output.name),
onnx_outputs,
encoder_onnx_config.atol_for_validation,
)
preprocessor = AutoTokenizer.from_pretrained(model_name)
with NamedTemporaryFile("w") as decoder_output:
onnx_inputs, onnx_outputs = export(
preprocessor,
decoder_model,
decoder_onnx_config,
onnx_config.default_onnx_opset,
Path(decoder_output.name),
device=device,
)
validate_model_outputs(
decoder_onnx_config,
preprocessor,
decoder_model,
Path(decoder_output.name),
onnx_outputs,
decoder_onnx_config.atol_for_validation,
)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@@ -363,6 +430,28 @@ class OnnxExportTestCaseV2(TestCase):
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
@slow
@require_torch
@require_vision
@require_rjieba
def test_pytorch_export_encoder_decoder_models(
self, test_name, name, model_name, feature, onnx_config_class_constructor
):
self._onnx_export_encoder_decoder_models(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
@slow
@require_torch
@require_vision
@require_rjieba
def test_pytorch_export_encoder_decoder_models_on_cuda(
self, test_name, name, model_name, feature, onnx_config_class_constructor
):
self._onnx_export_encoder_decoder_models(
test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda"
)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch