Add support for Perceiver ONNX export (#17213)
* Start adding perceiver support for ONNX * Fix pad token bug for fast tokenizers * Fix formatting * Make get_preprocesor more opinionated (processor priority, otherwise tokenizer/feature extractor) * Clean docs format * Minor cleanup following @sgugger's comments * Fix typo in docs * Fix another docs typo * Fix one more typo in docs * Update src/transformers/onnx/utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/onnx/utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/onnx/utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5c17918fe4
commit
babeff5524
@@ -70,6 +70,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- mBART
|
- mBART
|
||||||
- MobileBERT
|
- MobileBERT
|
||||||
- OpenAI GPT-2
|
- OpenAI GPT-2
|
||||||
|
- Perceiver
|
||||||
- PLBart
|
- PLBart
|
||||||
- RoBERTa
|
- RoBERTa
|
||||||
- RoFormer
|
- RoFormer
|
||||||
|
|||||||
@@ -14,8 +14,15 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Perceiver model configuration"""
|
""" Perceiver model configuration"""
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Mapping, Optional, Union
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...feature_extraction_utils import FeatureExtractionMixin
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
|
from ...onnx.utils import compute_effective_axis_dimension
|
||||||
|
from ...tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
|
from ...utils import TensorType, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -172,3 +179,63 @@ class PerceiverConfig(PretrainedConfig):
|
|||||||
self.audio_samples_per_frame = audio_samples_per_frame
|
self.audio_samples_per_frame = audio_samples_per_frame
|
||||||
self.samples_per_patch = samples_per_patch
|
self.samples_per_patch = samples_per_patch
|
||||||
self.output_shape = output_shape
|
self.output_shape = output_shape
|
||||||
|
|
||||||
|
|
||||||
|
class PerceiverOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("inputs", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
num_choices: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
num_channels: int = 3,
|
||||||
|
image_width: int = 40,
|
||||||
|
image_height: int = 40,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
# copied from `transformers.onnx.config.OnnxConfig` and slightly altered/simplified
|
||||||
|
|
||||||
|
if isinstance(preprocessor, PreTrainedTokenizerBase):
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(
|
||||||
|
batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
|
||||||
|
)
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
||||||
|
token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
|
||||||
|
seq_length = compute_effective_axis_dimension(
|
||||||
|
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
|
||||||
|
)
|
||||||
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
|
dummy_input = [" ".join(["a"]) * seq_length] * batch_size
|
||||||
|
inputs = dict(preprocessor(dummy_input, return_tensors=framework))
|
||||||
|
inputs["inputs"] = inputs.pop("input_ids")
|
||||||
|
return inputs
|
||||||
|
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
|
||||||
|
dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
|
||||||
|
inputs = dict(preprocessor(images=dummy_input, return_tensors=framework))
|
||||||
|
inputs["inputs"] = inputs.pop("pixel_values")
|
||||||
|
return inputs
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
|
||||||
|
)
|
||||||
|
|||||||
@@ -2735,7 +2735,9 @@ def _check_or_build_spatial_positions(pos, index_dims, batch_size):
|
|||||||
"""
|
"""
|
||||||
if pos is None:
|
if pos is None:
|
||||||
pos = build_linear_positions(index_dims)
|
pos = build_linear_positions(index_dims)
|
||||||
pos = torch.broadcast_to(pos[None], (batch_size,) + pos.shape)
|
# equivalent to `torch.broadcast_to(pos[None], (batch_size,) + pos.shape)`
|
||||||
|
# but `torch.broadcast_to` cannot be converted to ONNX
|
||||||
|
pos = pos[None].expand((batch_size,) + pos.shape)
|
||||||
pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
|
pos = torch.reshape(pos, [batch_size, np.prod(index_dims), -1])
|
||||||
else:
|
else:
|
||||||
# Just a warning label: you probably don't want your spatial features to
|
# Just a warning label: you probably don't want your spatial features to
|
||||||
@@ -2840,7 +2842,8 @@ class PerceiverEmbeddingDecoder(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor, embedding_layer: torch.Tensor) -> torch.Tensor:
|
||||||
batch_size, seq_len, d_model = hidden_states.shape
|
batch_size, seq_len, d_model = hidden_states.shape
|
||||||
output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.T) # Flatten batch dim
|
# Flatten batch dim
|
||||||
|
output = torch.matmul(hidden_states.reshape([-1, d_model]), embedding_layer.weight.transpose(0, 1))
|
||||||
output = output + self.bias
|
output = output + self.bias
|
||||||
|
|
||||||
return output.reshape([batch_size, seq_len, self.vocab_size])
|
return output.reshape([batch_size, seq_len, self.vocab_size])
|
||||||
@@ -3166,9 +3169,9 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
|
|||||||
if self.prep_type != "patches":
|
if self.prep_type != "patches":
|
||||||
# move channels to last dimension, as the _build_network_inputs method below expects this
|
# move channels to last dimension, as the _build_network_inputs method below expects this
|
||||||
if inputs.ndim == 4:
|
if inputs.ndim == 4:
|
||||||
inputs = torch.moveaxis(inputs, 1, -1)
|
inputs = torch.permute(inputs, (0, 2, 3, 1))
|
||||||
elif inputs.ndim == 5:
|
elif inputs.ndim == 5:
|
||||||
inputs = torch.moveaxis(inputs, 2, -1)
|
inputs = torch.permute(inputs, (0, 1, 3, 4, 2))
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported data format for conv1x1.")
|
raise ValueError("Unsupported data format for conv1x1.")
|
||||||
|
|
||||||
|
|||||||
@@ -15,9 +15,8 @@
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from ..models.auto import AutoConfig, AutoFeatureExtractor, AutoTokenizer
|
from ..models.auto import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
||||||
from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
|
from ..onnx.utils import get_preprocessor
|
||||||
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .convert import export, validate_model_outputs
|
from .convert import export, validate_model_outputs
|
||||||
from .features import FeaturesManager
|
from .features import FeaturesManager
|
||||||
@@ -43,6 +42,13 @@ def main():
|
|||||||
)
|
)
|
||||||
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
|
parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")
|
||||||
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
|
parser.add_argument("--cache_dir", type=str, default=None, help="Path indicating where to store cache.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--preprocessor",
|
||||||
|
type=str,
|
||||||
|
choices=["auto", "tokenizer", "feature_extractor", "processor"],
|
||||||
|
default="auto",
|
||||||
|
help="Which type of preprocessor to use. 'auto' tries to automatically detect it.",
|
||||||
|
)
|
||||||
|
|
||||||
# Retrieve CLI arguments
|
# Retrieve CLI arguments
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -51,15 +57,17 @@ def main():
|
|||||||
if not args.output.parent.exists():
|
if not args.output.parent.exists():
|
||||||
args.output.parent.mkdir(parents=True)
|
args.output.parent.mkdir(parents=True)
|
||||||
|
|
||||||
# Check the modality of the inputs and instantiate the appropriate preprocessor
|
# Instantiate the appropriate preprocessor
|
||||||
# TODO(lewtun): Refactor this as a function if we need to check modalities elsewhere as well
|
if args.preprocessor == "auto":
|
||||||
config = AutoConfig.from_pretrained(args.model)
|
preprocessor = get_preprocessor(args.model)
|
||||||
if config.model_type in TOKENIZER_MAPPING_NAMES:
|
elif args.preprocessor == "tokenizer":
|
||||||
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
||||||
elif config.model_type in FEATURE_EXTRACTOR_MAPPING_NAMES:
|
elif args.preprocessor == "feature_extractor":
|
||||||
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
||||||
|
elif args.preprocessor == "processor":
|
||||||
|
preprocessor = AutoProcessor.from_pretrained(args.model)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model type: {config.model_type}")
|
raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
|
||||||
|
|
||||||
# Allocate the model
|
# Allocate the model
|
||||||
model = FeaturesManager.get_model_from_feature(
|
model = FeaturesManager.get_model_from_feature(
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from ..models.m2m_100 import M2M100OnnxConfig
|
|||||||
from ..models.marian import MarianOnnxConfig
|
from ..models.marian import MarianOnnxConfig
|
||||||
from ..models.mbart import MBartOnnxConfig
|
from ..models.mbart import MBartOnnxConfig
|
||||||
from ..models.mobilebert import MobileBertOnnxConfig
|
from ..models.mobilebert import MobileBertOnnxConfig
|
||||||
|
from ..models.perceiver.configuration_perceiver import PerceiverOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
from ..models.roformer import RoFormerOnnxConfig
|
from ..models.roformer import RoFormerOnnxConfig
|
||||||
from ..models.squeezebert import SqueezeBertOnnxConfig
|
from ..models.squeezebert import SqueezeBertOnnxConfig
|
||||||
@@ -332,6 +333,12 @@ class FeaturesManager:
|
|||||||
"m2m-100": supported_features_mapping(
|
"m2m-100": supported_features_mapping(
|
||||||
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
|
"default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig
|
||||||
),
|
),
|
||||||
|
"perceiver": supported_features_mapping(
|
||||||
|
"image-classification",
|
||||||
|
"masked-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
onnx_config_cls=PerceiverOnnxConfig,
|
||||||
|
),
|
||||||
"roberta": supported_features_mapping(
|
"roberta": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
@@ -516,3 +523,18 @@ class FeaturesManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
|
return model.config.model_type, FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
|
||||||
|
|
||||||
|
def get_config(model_type: str, feature: str) -> OnnxConfig:
|
||||||
|
"""
|
||||||
|
Gets the OnnxConfig for a model_type and feature combination.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type (`str`):
|
||||||
|
The model type to retrieve the config for.
|
||||||
|
feature (`str`):
|
||||||
|
The feature to retrieve the config for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`OnnxConfig`: config for the combination
|
||||||
|
"""
|
||||||
|
return FeaturesManager._SUPPORTED_MODEL_TYPE[model_type][feature]
|
||||||
|
|||||||
@@ -14,6 +14,9 @@
|
|||||||
|
|
||||||
from ctypes import c_float, sizeof
|
from ctypes import c_float, sizeof
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from .. import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
class ParameterFormat(Enum):
|
class ParameterFormat(Enum):
|
||||||
@@ -61,3 +64,41 @@ def compute_serialized_parameters_size(num_parameters: int, dtype: ParameterForm
|
|||||||
Size (in byte) taken to save all the parameters
|
Size (in byte) taken to save all the parameters
|
||||||
"""
|
"""
|
||||||
return num_parameters * dtype.size
|
return num_parameters * dtype.size
|
||||||
|
|
||||||
|
|
||||||
|
def get_preprocessor(model_name: str) -> Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]:
|
||||||
|
"""
|
||||||
|
Gets a preprocessor (tokenizer, feature extractor or processor) that is available for `model_name`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name (`str`): Name of the model for which a preprocessor are loaded.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Optional[Union[AutoTokenizer, AutoFeatureExtractor, AutoProcessor]]`:
|
||||||
|
If a processor is found, it is returned. Otherwise, if a tokenizer or a feature extractor exists, it is
|
||||||
|
returned. If both a tokenizer and a feature extractor exist, an error is raised. The function returns
|
||||||
|
`None` if no preprocessor is found.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return AutoProcessor.from_pretrained(model_name)
|
||||||
|
except (ValueError, OSError, KeyError):
|
||||||
|
tokenizer, feature_extractor = None, None
|
||||||
|
try:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
except (OSError, KeyError):
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
|
||||||
|
except (OSError, KeyError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
if tokenizer is not None and feature_extractor is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Couldn't auto-detect preprocessor for {model_name}. Found both a tokenizer and a feature extractor."
|
||||||
|
)
|
||||||
|
elif tokenizer is None and feature_extractor is None:
|
||||||
|
return None
|
||||||
|
elif tokenizer is not None:
|
||||||
|
return tokenizer
|
||||||
|
else:
|
||||||
|
return feature_extractor
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, is_tf_available, is_torch_available
|
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
|
||||||
from transformers.onnx import (
|
from transformers.onnx import (
|
||||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||||
OnnxConfig,
|
OnnxConfig,
|
||||||
@@ -15,7 +15,11 @@ from transformers.onnx import (
|
|||||||
export,
|
export,
|
||||||
validate_model_outputs,
|
validate_model_outputs,
|
||||||
)
|
)
|
||||||
from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
|
from transformers.onnx.utils import (
|
||||||
|
compute_effective_axis_dimension,
|
||||||
|
compute_serialized_parameters_size,
|
||||||
|
get_preprocessor,
|
||||||
|
)
|
||||||
from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow
|
from transformers.testing_utils import require_onnx, require_rjieba, require_tf, require_torch, require_vision, slow
|
||||||
|
|
||||||
|
|
||||||
@@ -189,6 +193,8 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("deit", "facebook/deit-small-patch16-224"),
|
("deit", "facebook/deit-small-patch16-224"),
|
||||||
("beit", "microsoft/beit-base-patch16-224"),
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
("data2vec-text", "facebook/data2vec-text-base"),
|
("data2vec-text", "facebook/data2vec-text-base"),
|
||||||
|
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
|
||||||
|
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
@@ -226,10 +232,15 @@ TENSORFLOW_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {}
|
|||||||
def _get_models_to_test(export_models_list):
|
def _get_models_to_test(export_models_list):
|
||||||
models_to_test = []
|
models_to_test = []
|
||||||
if is_torch_available() or is_tf_available():
|
if is_torch_available() or is_tf_available():
|
||||||
for name, model in export_models_list:
|
for name, model, *features in export_models_list:
|
||||||
for feature, onnx_config_class_constructor in FeaturesManager.get_supported_features_for_model_type(
|
if features:
|
||||||
name
|
feature_config_mapping = {
|
||||||
).items():
|
feature: FeaturesManager.get_config(name, feature) for _ in features for feature in _
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
feature_config_mapping = FeaturesManager.get_supported_features_for_model_type(name)
|
||||||
|
|
||||||
|
for feature, onnx_config_class_constructor in feature_config_mapping.items():
|
||||||
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
models_to_test.append((f"{name}_{feature}", name, model, feature, onnx_config_class_constructor))
|
||||||
return sorted(models_to_test)
|
return sorted(models_to_test)
|
||||||
else:
|
else:
|
||||||
@@ -261,16 +272,11 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check the modality of the inputs and instantiate the appropriate preprocessor
|
preprocessor = get_preprocessor(model_name)
|
||||||
if model.main_input_name == "input_ids":
|
|
||||||
preprocessor = AutoTokenizer.from_pretrained(model_name)
|
# Useful for causal lm models that do not use pad tokens.
|
||||||
# Useful for causal lm models that do not use pad tokens.
|
if isinstance(preprocessor, PreTrainedTokenizerBase) and not getattr(config, "pad_token_id", None):
|
||||||
if not getattr(config, "pad_token_id", None):
|
config.pad_token_id = preprocessor.eos_token_id
|
||||||
config.pad_token_id = preprocessor.eos_token_id
|
|
||||||
elif model.main_input_name == "pixel_values":
|
|
||||||
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported model input name: {model.main_input_name}")
|
|
||||||
|
|
||||||
with NamedTemporaryFile("w") as output:
|
with NamedTemporaryFile("w") as output:
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user