Export TensorFlow models to ONNX with dynamic input shapes (#19255)
* validate onnx models with a different input geometry than saved with * only test working features for now * simpler test skipping * rm TODO * expose batch_size/seq_length on vit * skip certain name, feature, framework parameterizations known to fail validation * Trigger CI * Trigger CI
This commit is contained in:
@@ -355,11 +355,17 @@ class CLIPOnnxConfig(OnnxConfig):
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
processor: "ProcessorMixin",
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
framework: Optional["TensorType"] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
|
||||
image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
|
||||
text_input_dict = super().generate_dummy_inputs(
|
||||
processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
|
||||
)
|
||||
image_input_dict = super().generate_dummy_inputs(
|
||||
processor.feature_extractor, batch_size=batch_size, framework=framework
|
||||
)
|
||||
return {**text_input_dict, **image_input_dict}
|
||||
|
||||
@property
|
||||
|
||||
@@ -381,11 +381,17 @@ class GroupViTOnnxConfig(OnnxConfig):
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
processor: "ProcessorMixin",
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
framework: Optional["TensorType"] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
|
||||
image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
|
||||
text_input_dict = super().generate_dummy_inputs(
|
||||
processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
|
||||
)
|
||||
image_input_dict = super().generate_dummy_inputs(
|
||||
processor.feature_extractor, batch_size=batch_size, framework=framework
|
||||
)
|
||||
return {**text_input_dict, **image_input_dict}
|
||||
|
||||
@property
|
||||
|
||||
@@ -372,11 +372,17 @@ class OwlViTOnnxConfig(OnnxConfig):
|
||||
def generate_dummy_inputs(
|
||||
self,
|
||||
processor: "ProcessorMixin",
|
||||
batch_size: int = -1,
|
||||
seq_length: int = -1,
|
||||
framework: Optional["TensorType"] = None,
|
||||
) -> Mapping[str, Any]:
|
||||
|
||||
text_input_dict = super().generate_dummy_inputs(processor.tokenizer, framework=framework)
|
||||
image_input_dict = super().generate_dummy_inputs(processor.feature_extractor, framework=framework)
|
||||
text_input_dict = super().generate_dummy_inputs(
|
||||
processor.tokenizer, batch_size=batch_size, seq_length=seq_length, framework=framework
|
||||
)
|
||||
image_input_dict = super().generate_dummy_inputs(
|
||||
processor.feature_extractor, batch_size=batch_size, framework=framework
|
||||
)
|
||||
return {**text_input_dict, **image_input_dict}
|
||||
|
||||
@property
|
||||
|
||||
@@ -262,7 +262,9 @@ def export_tensorflow(
|
||||
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
|
||||
onnx_outputs = list(config.outputs.keys())
|
||||
|
||||
input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
|
||||
input_signature = [
|
||||
tf.TensorSpec([None] * tensor.ndim, dtype=tensor.dtype, name=key) for key, tensor in model_inputs.items()
|
||||
]
|
||||
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
|
||||
onnx.save(onnx_model, output.as_posix())
|
||||
config.restore_ops()
|
||||
@@ -363,12 +365,22 @@ def validate_model_outputs(
|
||||
logger.info("Overwriting the `preprocessor` argument with `tokenizer` to generate dummmy inputs.")
|
||||
preprocessor = tokenizer
|
||||
|
||||
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||
# generate inputs with a different batch_size and seq_len that was used for conversion to properly test
|
||||
# dynamic input shapes.
|
||||
if is_torch_available() and issubclass(type(reference_model), PreTrainedModel):
|
||||
reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH)
|
||||
reference_model_inputs = config.generate_dummy_inputs(
|
||||
preprocessor,
|
||||
batch_size=config.default_fixed_batch + 1,
|
||||
seq_length=config.default_fixed_sequence + 1,
|
||||
framework=TensorType.PYTORCH,
|
||||
)
|
||||
else:
|
||||
reference_model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.TENSORFLOW)
|
||||
reference_model_inputs = config.generate_dummy_inputs(
|
||||
preprocessor,
|
||||
batch_size=config.default_fixed_batch + 1,
|
||||
seq_length=config.default_fixed_sequence + 1,
|
||||
framework=TensorType.TENSORFLOW,
|
||||
)
|
||||
|
||||
# Create ONNX Runtime session
|
||||
options = SessionOptions()
|
||||
|
||||
@@ -284,10 +284,12 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
Integration tests ensuring supported models are correctly exported
|
||||
"""
|
||||
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
|
||||
def _onnx_export(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu", framework="pt"
|
||||
):
|
||||
from transformers.onnx import export
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature, framework=framework)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
model = model_class.from_config(config)
|
||||
|
||||
@@ -296,6 +298,22 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
if model.__class__.__name__.startswith("Yolos") and device != "cpu":
|
||||
return
|
||||
|
||||
# ONNX inference fails with the following name, feature, framework parameterizations
|
||||
# See: https://github.com/huggingface/transformers/issues/19357
|
||||
if (name, feature, framework) in {
|
||||
("deberta-v2", "question-answering", "pt"),
|
||||
("deberta-v2", "multiple-choice", "pt"),
|
||||
("roformer", "multiple-choice", "pt"),
|
||||
("groupvit", "default", "pt"),
|
||||
("perceiver", "masked-lm", "pt"),
|
||||
("perceiver", "sequence-classification", "pt"),
|
||||
("perceiver", "image-classification", "pt"),
|
||||
("bert", "multiple-choice", "tf"),
|
||||
("camembert", "multiple-choice", "tf"),
|
||||
("roberta", "multiple-choice", "tf"),
|
||||
}:
|
||||
return
|
||||
|
||||
onnx_config = onnx_config_class_constructor(model.config)
|
||||
|
||||
if is_torch_available():
|
||||
@@ -364,13 +382,13 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
@require_tf
|
||||
@require_vision
|
||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS), skip_on_empty=True)
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS), skip_on_empty=True)
|
||||
@slow
|
||||
@@ -378,7 +396,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_tensorflow_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, framework="tf")
|
||||
|
||||
|
||||
class StableDropoutTestCase(TestCase):
|
||||
|
||||
Reference in New Issue
Block a user