Export TensorFlow models to ONNX with dynamic input shapes (#19255)
* validate onnx models with a different input geometry than saved with * only test working features for now * simpler test skipping * rm TODO * expose batch_size/seq_length on vit * skip certain name, feature, framework parameterizations known to fail validation * Trigger CI * Trigger CI
This commit is contained in:
@@ -284,10 +284,12 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
Integration tests ensuring supported models are correctly exported
|
||||
"""
|
||||
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
|
||||
def _onnx_export(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu", framework="pt"
|
||||
):
|
||||
from transformers.onnx import export
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature, framework=framework)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
model = model_class.from_config(config)
|
||||
|
||||
@@ -296,6 +298,22 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
if model.__class__.__name__.startswith("Yolos") and device != "cpu":
|
||||
return
|
||||
|
||||
# ONNX inference fails with the following name, feature, framework parameterizations
|
||||
# See: https://github.com/huggingface/transformers/issues/19357
|
||||
if (name, feature, framework) in {
|
||||
("deberta-v2", "question-answering", "pt"),
|
||||
("deberta-v2", "multiple-choice", "pt"),
|
||||
("roformer", "multiple-choice", "pt"),
|
||||
("groupvit", "default", "pt"),
|
||||
("perceiver", "masked-lm", "pt"),
|
||||
("perceiver", "sequence-classification", "pt"),
|
||||
("perceiver", "image-classification", "pt"),
|
||||
("bert", "multiple-choice", "tf"),
|
||||
("camembert", "multiple-choice", "tf"),
|
||||
("roberta", "multiple-choice", "tf"),
|
||||
}:
|
||||
return
|
||||
|
||||
onnx_config = onnx_config_class_constructor(model.config)
|
||||
|
||||
if is_torch_available():
|
||||
@@ -364,13 +382,13 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
@require_tf
|
||||
@require_vision
|
||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
|
||||
@slow
|
||||
@@ -378,7 +396,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_tensorflow_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
|
||||
|
||||
|
||||
class StableDropoutTestCase(TestCase):
|
||||
|
||||
Reference in New Issue
Block a user