Add onnx support for VisionEncoderDecoder (#19254)
* Add onnx support for VisionEncoderDecoder * Add onnx support for VisionEncoderDecoder * Removed unused import * Rename encoder hidden state Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update docstrings and removed redundant code * Added test function for enc-dec models * Update doc string text Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * fixed code style Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -161,7 +161,6 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
||||
"""
|
||||
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
||||
with self.subTest(name):
|
||||
|
||||
# without past
|
||||
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
||||
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
||||
@@ -220,6 +219,10 @@ PYTORCH_EXPORT_MODELS = {
|
||||
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
|
||||
("vision-encoder-decoder", "nlpconnect/vit-gpt2-image-captioning"),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
("bloom", "bigscience/bloom-560m"),
|
||||
("gpt2", "gpt2"),
|
||||
@@ -347,6 +350,70 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
except (RuntimeError, ValueError) as e:
|
||||
self.fail(f"{name}, {feature} -> {e}")
|
||||
|
||||
def _onnx_export_encoder_decoder_models(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"
|
||||
):
|
||||
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||
from transformers.onnx import export
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
model = model_class.from_config(config)
|
||||
|
||||
onnx_config = onnx_config_class_constructor(model.config)
|
||||
|
||||
if is_torch_available():
|
||||
from transformers.utils import torch_version
|
||||
|
||||
if torch_version < onnx_config.torch_onnx_minimum_version:
|
||||
pytest.skip(
|
||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
||||
)
|
||||
|
||||
encoder_model = model.get_encoder()
|
||||
decoder_model = model.get_decoder()
|
||||
|
||||
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
|
||||
decoder_onnx_config = onnx_config.get_decoder_config(encoder_model.config, decoder_model.config, feature)
|
||||
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
|
||||
|
||||
onnx_opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
|
||||
|
||||
with NamedTemporaryFile("w") as encoder_output:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor, encoder_model, encoder_onnx_config, onnx_opset, Path(encoder_output.name), device=device
|
||||
)
|
||||
validate_model_outputs(
|
||||
encoder_onnx_config,
|
||||
preprocessor,
|
||||
encoder_model,
|
||||
Path(encoder_output.name),
|
||||
onnx_outputs,
|
||||
encoder_onnx_config.atol_for_validation,
|
||||
)
|
||||
|
||||
preprocessor = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
with NamedTemporaryFile("w") as decoder_output:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor,
|
||||
decoder_model,
|
||||
decoder_onnx_config,
|
||||
onnx_config.default_onnx_opset,
|
||||
Path(decoder_output.name),
|
||||
device=device,
|
||||
)
|
||||
validate_model_outputs(
|
||||
decoder_onnx_config,
|
||||
preprocessor,
|
||||
decoder_model,
|
||||
Path(decoder_output.name),
|
||||
onnx_outputs,
|
||||
decoder_onnx_config.atol_for_validation,
|
||||
)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@@ -363,6 +430,28 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
@require_rjieba
|
||||
def test_pytorch_export_encoder_decoder_models(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export_encoder_decoder_models(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
@require_rjieba
|
||||
def test_pytorch_export_encoder_decoder_models_on_cuda(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export_encoder_decoder_models(
|
||||
test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda"
|
||||
)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user