Add Tensorflow handling of ONNX conversion (#13831)
* Add TensorFlow support for ONNX export * Change documentation to mention conversion with Tensorflow * Refactor export into export_pytorch and export_tensorflow * Check model's type instead of framework installation to choose between TF and Pytorch Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Alberto Bégué <alberto.begue@della.ai> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoConfig, AutoTokenizer, is_torch_available
|
||||
from transformers import AutoConfig, AutoTokenizer, is_tf_available, is_torch_available
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
@@ -15,11 +15,11 @@ from transformers.onnx import (
|
||||
from transformers.onnx.config import OnnxConfigWithPast
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
if is_torch_available() or is_tf_available():
|
||||
from transformers.onnx.features import FeaturesManager
|
||||
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.testing_utils import require_onnx, require_torch, slow
|
||||
from transformers.testing_utils import require_onnx, require_tf, require_torch, slow
|
||||
|
||||
|
||||
@require_onnx
|
||||
@@ -192,19 +192,44 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_DEFAULT_MODELS = {
|
||||
("albert", "hf-internal-testing/tiny-albert"),
|
||||
("bert", "bert-base-cased"),
|
||||
("ibert", "kssteven/ibert-roberta-base"),
|
||||
("camembert", "camembert-base"),
|
||||
("distilbert", "distilbert-base-cased"),
|
||||
("roberta", "roberta-base"),
|
||||
("xlm-roberta", "xlm-roberta-base"),
|
||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
|
||||
("gpt2", "gpt2"),
|
||||
("gpt-neo", "EleutherAI/gpt-neo-125M"),
|
||||
}
|
||||
|
||||
TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
|
||||
("bart", "facebook/bart-base"),
|
||||
("mbart", "sshleifer/tiny-mbart"),
|
||||
("t5", "t5-small"),
|
||||
("marian", "Helsinki-NLP/opus-mt-en-de"),
|
||||
}
|
||||
|
||||
|
||||
def _get_models_to_test(export_models_list):
|
||||
models_to_test = []
|
||||
if not is_torch_available():
|
||||
# Returning some dummy test that should not be ever called because of the @require_torch decorator.
|
||||
if is_torch_available() or is_tf_available():
|
||||
for (name, model) in export_models_list:
|
||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||
name
|
||||
).items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
return sorted(models_to_test)
|
||||
else:
|
||||
# Returning some dummy test that should not be ever called because of the @require_torch / @require_tf
|
||||
# decorators.
|
||||
# The reason for not returning an empty list is because parameterized.expand complains when it's empty.
|
||||
return [("dummy", "dummy", "dummy", "dummy", OnnxConfig.from_model_config)]
|
||||
for (name, model) in export_models_list:
|
||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||
name
|
||||
).items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
return sorted(models_to_test)
|
||||
|
||||
|
||||
class OnnxExportTestCaseV2(TestCase):
|
||||
@@ -212,7 +237,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
Integration tests ensuring supported models are correctly exported
|
||||
"""
|
||||
|
||||
def _pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
from transformers.onnx import export
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@@ -246,13 +271,13 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pytorch_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@@ -260,4 +285,24 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_pytorch_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._pytorch_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_DEFAULT_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_with_past(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tensorflow_export_seq2seq_with_past(
|
||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||
):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
Reference in New Issue
Block a user