Add support for Perceiver ONNX export (#17213)
* Start adding perceiver support for ONNX
* Fix pad token bug for fast tokenizers
* Fix formatting
* Make get_preprocessor more opinionated (processor priority, otherwise tokenizer/feature extractor)
* Clean docs format
* Minor cleanup following @sgugger's comments
* Fix typo in docs
* Fix another docs typo
* Fix one more typo in docs
* Update src/transformers/onnx/utils.py

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/onnx/utils.py

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/onnx/utils.py

  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5c17918fe4
commit
babeff5524
@@ -6,7 +6,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
||||
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
@@ -15,7 +15,11 @@ from transformers.onnx import (
|
||||
export,
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
||||
from transformers.onnx.utils import (
|
||||
compute_effective_axis_dimension,
|
||||
compute_serialized_parameters_size,
|
||||
get_preprocessor,
|
||||
)
|
||||
from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow
|
||||
|
||||
|
||||
@@ -189,6 +193,8 @@ PYTORCH_EXPORT_MODELS = {
|
||||
("deit", "facebook/deit-small-patch16-224"),
|
||||
("beit", "microsoft/beit-base-patch16-224"),
|
||||
("data2vec-text", "facebook/data2vec-text-base"),
|
||||
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
|
||||
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
||||
}
|
||||
|
||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||
@@ -226,10 +232,15 @@ TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
|
||||
def _get_models_to_test(export_models_list):
|
||||
models_to_test = []
|
||||
if is_torch_available() or is_tf_available():
|
||||
for name, model in export_models_list:
|
||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
||||
name
|
||||
).items():
|
||||
for name, model, *features in export_models_list:
|
||||
if features:
|
||||
feature_config_mapping = {
|
||||
feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
|
||||
}
|
||||
else:
|
||||
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
|
||||
|
||||
for feature, onnx_config_class_constructor in feature_config_mapping.items():
|
||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||
return sorted(models_to_test)
|
||||
else:
|
||||
@@ -261,16 +272,11 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
||||
)
|
||||
|
||||
# Check the modality of the inputs and instantiate the appropriate preprocessor
|
||||
if model.main_input_name == "input_ids":
|
||||
preprocessor = AutoTokenizer.from_pretrained(model_name)
|
||||
# Useful for causal lm models that do not use pad tokens.
|
||||
if not getattr(config, "pad_token_id", None):
|
||||
config.pad_token_id = preprocessor.eos_token_id
|
||||
elif model.main_input_name == "pixel_values":
|
||||
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model input name: {model.main_input_name}")
|
||||
preprocessor = get_preprocessor(model_name)
|
||||
|
||||
# Useful for causal lm models that do not use pad tokens.
|
||||
if isinstance(preprocessor, PreTrainedTokenizerBase) and not getattr(config, "pad_token_id", None):
|
||||
config.pad_token_id = preprocessor.eos_token_id
|
||||
|
||||
with NamedTemporaryFile("w") as output:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user