Add onnx support for VisionEncoderDecoder (#19254)
* Add onnx support for VisionEncoderDecoder * Add onnx support for VisionEncoderDecoder * Removed unused import * Rename encoder hidden state Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update docstrings and removed redundant code * Added test function for enc-dec models * Update doc string text Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * fixed code style Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -96,6 +96,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- SqueezeBERT
|
- SqueezeBERT
|
||||||
- Swin Transformer
|
- Swin Transformer
|
||||||
- T5
|
- T5
|
||||||
|
- Vision Encoder decoder
|
||||||
- ViT
|
- ViT
|
||||||
- XLM
|
- XLM
|
||||||
- XLM-RoBERTa
|
- XLM-RoBERTa
|
||||||
@@ -294,6 +295,13 @@ that can be used for fast autoregressive decoding.
|
|||||||
|
|
||||||
</Tip>
|
</Tip>
|
||||||
|
|
||||||
|
<Tip>
|
||||||
|
|
||||||
|
For `VisionEncoderDecoder` type models, the encoder and decoder parts are
|
||||||
|
exported separately as two ONNX files named `encoder_model.onnx` and `decoder_model.onnx` respectively.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
|
||||||
## Exporting a model for an unsupported architecture
|
## Exporting a model for an unsupported architecture
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ from ...utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_import_structure = {"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig"]}
|
_import_structure = {
|
||||||
|
"configuration_vision_encoder_decoder": ["VisionEncoderDecoderConfig", "VisionEncoderDecoderOnnxConfig"]
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
@@ -54,7 +56,7 @@ else:
|
|||||||
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
|
_import_structure["modeling_flax_vision_encoder_decoder"] = ["FlaxVisionEncoderDecoderModel"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig
|
from .configuration_vision_encoder_decoder import VisionEncoderDecoderConfig, VisionEncoderDecoderOnnxConfig
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
|
|||||||
@@ -15,12 +15,19 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
from typing import TYPE_CHECKING, Any, Mapping, Optional, OrderedDict
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx import OnnxConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..auto.configuration_auto import AutoConfig
|
from ..auto.configuration_auto import AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ... import PreTrainedTokenizerBase, TensorType
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -119,3 +126,97 @@ class VisionEncoderDecoderConfig(PretrainedConfig):
|
|||||||
output["decoder"] = self.decoder.to_dict()
|
output["decoder"] = self.decoder.to_dict()
|
||||||
output["model_type"] = self.__class__.model_type
|
output["model_type"] = self.__class__.model_type
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionEncoderDecoderEncoderOnnxConfig(OnnxConfig):
|
||||||
|
torch_onnx_minimum_version = version.parse("1.11")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
return 1e-4
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
return OrderedDict({"last_hidden_state": {0: "batch", 1: "encoder_sequence"}})
|
||||||
|
|
||||||
|
|
||||||
|
class VisionEncoderDecoderDecoderOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
common_inputs = OrderedDict()
|
||||||
|
common_inputs["input_ids"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
common_inputs["attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
|
||||||
|
common_inputs["encoder_hidden_states"] = {0: "batch", 1: "encoder_sequence"}
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional["TensorType"] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
common_inputs = OrderedDict()
|
||||||
|
|
||||||
|
dummy_input = super().generate_dummy_inputs(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
|
batch, encoder_sequence = dummy_input["input_ids"].shape
|
||||||
|
encoder_hidden_states_shape = (batch, encoder_sequence, self._config.encoder_hidden_size)
|
||||||
|
common_inputs["input_ids"] = dummy_input.pop("input_ids")
|
||||||
|
common_inputs["attention_mask"] = dummy_input.pop("attention_mask")
|
||||||
|
common_inputs["encoder_hidden_states"] = torch.zeros(encoder_hidden_states_shape)
|
||||||
|
|
||||||
|
return common_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class VisionEncoderDecoderOnnxConfig(OnnxConfig):
|
||||||
|
@property
|
||||||
|
def inputs(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_encoder_config(self, encoder_config: PretrainedConfig) -> OnnxConfig:
|
||||||
|
r"""
|
||||||
|
Returns ONNX encoder config for `VisionEncoderDecoder` model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_config (`PretrainedConfig`):
|
||||||
|
The encoder model's configuration to use when exporting to ONNX.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`VisionEncoderDecoderEncoderOnnxConfig`]: An instance of the ONNX configuration object
|
||||||
|
"""
|
||||||
|
return VisionEncoderDecoderEncoderOnnxConfig(encoder_config)
|
||||||
|
|
||||||
|
def get_decoder_config(
|
||||||
|
self, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig, feature: str = "default"
|
||||||
|
) -> OnnxConfig:
|
||||||
|
r"""
|
||||||
|
Returns ONNX decoder config for `VisionEncoderDecoder` model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
encoder_config (`PretrainedConfig`):
|
||||||
|
The encoder model's configuration to use when exporting to ONNX.
|
||||||
|
decoder_config (`PretrainedConfig`):
|
||||||
|
The decoder model's configuration to use when exporting to ONNX
|
||||||
|
feature (`str`, *optional*):
|
||||||
|
The type of feature to export the model with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`VisionEncoderDecoderDecoderOnnxConfig`]: An instance of the ONNX configuration object.
|
||||||
|
"""
|
||||||
|
decoder_config.encoder_hidden_size = encoder_config.hidden_size
|
||||||
|
return VisionEncoderDecoderDecoderOnnxConfig(decoder_config, feature)
|
||||||
|
|||||||
@@ -22,6 +22,9 @@ from .convert import export, validate_model_outputs
|
|||||||
from .features import FeaturesManager
|
from .features import FeaturesManager
|
||||||
|
|
||||||
|
|
||||||
|
ENCODER_DECODER_MODELS = ["vision-encoder-decoder"]
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser("Hugging Face Transformers ONNX exporter")
|
parser = ArgumentParser("Hugging Face Transformers ONNX exporter")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -65,48 +68,110 @@ def main():
|
|||||||
if not args.output.parent.exists():
|
if not args.output.parent.exists():
|
||||||
args.output.parent.mkdir(parents=True)
|
args.output.parent.mkdir(parents=True)
|
||||||
|
|
||||||
# Instantiate the appropriate preprocessor
|
|
||||||
if args.preprocessor == "auto":
|
|
||||||
preprocessor = get_preprocessor(args.model)
|
|
||||||
elif args.preprocessor == "tokenizer":
|
|
||||||
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
|
||||||
elif args.preprocessor == "feature_extractor":
|
|
||||||
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
|
||||||
elif args.preprocessor == "processor":
|
|
||||||
preprocessor = AutoProcessor.from_pretrained(args.model)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
|
|
||||||
|
|
||||||
# Allocate the model
|
# Allocate the model
|
||||||
model = FeaturesManager.get_model_from_feature(
|
model = FeaturesManager.get_model_from_feature(
|
||||||
args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
|
args.feature, args.model, framework=args.framework, cache_dir=args.cache_dir
|
||||||
)
|
)
|
||||||
|
|
||||||
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=args.feature)
|
||||||
onnx_config = model_onnx_config(model.config)
|
onnx_config = model_onnx_config(model.config)
|
||||||
|
|
||||||
# Ensure the requested opset is sufficient
|
if model_kind in ENCODER_DECODER_MODELS:
|
||||||
if args.opset is None:
|
encoder_model = model.get_encoder()
|
||||||
args.opset = onnx_config.default_onnx_opset
|
decoder_model = model.get_decoder()
|
||||||
|
|
||||||
if args.opset < onnx_config.default_onnx_opset:
|
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
|
||||||
raise ValueError(
|
decoder_onnx_config = onnx_config.get_decoder_config(
|
||||||
f"Opset {args.opset} is not sufficient to export {model_kind}. "
|
encoder_model.config, decoder_model.config, feature=args.feature
|
||||||
f"At least {onnx_config.default_onnx_opset} is required."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
onnx_inputs, onnx_outputs = export(
|
if args.opset is None:
|
||||||
preprocessor,
|
args.opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
|
||||||
model,
|
|
||||||
onnx_config,
|
|
||||||
args.opset,
|
|
||||||
args.output,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.atol is None:
|
if args.opset < min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset):
|
||||||
args.atol = onnx_config.atol_for_validation
|
raise ValueError(
|
||||||
|
f"Opset {args.opset} is not sufficient to export {model_kind}. At least "
|
||||||
|
f" {min(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)} is required."
|
||||||
|
)
|
||||||
|
|
||||||
validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
|
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
||||||
logger.info(f"All good, model saved at: {args.output.as_posix()}")
|
|
||||||
|
onnx_inputs, onnx_outputs = export(
|
||||||
|
preprocessor,
|
||||||
|
encoder_model,
|
||||||
|
encoder_onnx_config,
|
||||||
|
args.opset,
|
||||||
|
args.output.parent.joinpath("encoder_model.onnx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_model_outputs(
|
||||||
|
encoder_onnx_config,
|
||||||
|
preprocessor,
|
||||||
|
encoder_model,
|
||||||
|
args.output.parent.joinpath("encoder_model.onnx"),
|
||||||
|
onnx_outputs,
|
||||||
|
args.atol if args.atol else encoder_onnx_config.atol_for_validation,
|
||||||
|
)
|
||||||
|
|
||||||
|
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
||||||
|
|
||||||
|
onnx_inputs, onnx_outputs = export(
|
||||||
|
preprocessor,
|
||||||
|
decoder_model,
|
||||||
|
decoder_onnx_config,
|
||||||
|
args.opset,
|
||||||
|
args.output.parent.joinpath("decoder_model.onnx"),
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_model_outputs(
|
||||||
|
decoder_onnx_config,
|
||||||
|
preprocessor,
|
||||||
|
decoder_model,
|
||||||
|
args.output.parent.joinpath("decoder_model.onnx"),
|
||||||
|
onnx_outputs,
|
||||||
|
args.atol if args.atol else decoder_onnx_config.atol_for_validation,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"All good, model saved at: {args.output.parent.joinpath('encoder_model.onnx').as_posix()},"
|
||||||
|
f" {args.output.parent.joinpath('decoder_model.onnx').as_posix()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Instantiate the appropriate preprocessor
|
||||||
|
if args.preprocessor == "auto":
|
||||||
|
preprocessor = get_preprocessor(args.model)
|
||||||
|
elif args.preprocessor == "tokenizer":
|
||||||
|
preprocessor = AutoTokenizer.from_pretrained(args.model)
|
||||||
|
elif args.preprocessor == "feature_extractor":
|
||||||
|
preprocessor = AutoFeatureExtractor.from_pretrained(args.model)
|
||||||
|
elif args.preprocessor == "processor":
|
||||||
|
preprocessor = AutoProcessor.from_pretrained(args.model)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown preprocessor type '{args.preprocessor}'")
|
||||||
|
|
||||||
|
# Ensure the requested opset is sufficient
|
||||||
|
if args.opset is None:
|
||||||
|
args.opset = onnx_config.default_onnx_opset
|
||||||
|
|
||||||
|
if args.opset < onnx_config.default_onnx_opset:
|
||||||
|
raise ValueError(
|
||||||
|
f"Opset {args.opset} is not sufficient to export {model_kind}. "
|
||||||
|
f"At least {onnx_config.default_onnx_opset} is required."
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_inputs, onnx_outputs = export(
|
||||||
|
preprocessor,
|
||||||
|
model,
|
||||||
|
onnx_config,
|
||||||
|
args.opset,
|
||||||
|
args.output,
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.atol is None:
|
||||||
|
args.atol = onnx_config.atol_for_validation
|
||||||
|
|
||||||
|
validate_model_outputs(onnx_config, preprocessor, model, args.output, onnx_outputs, args.atol)
|
||||||
|
logger.info(f"All good, model saved at: {args.output.as_posix()}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class OnnxConfig(ABC):
|
|||||||
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
"seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
|
||||||
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
|
||||||
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
|
"vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
|
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
|
||||||
@@ -451,7 +452,6 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
|||||||
is_pair: bool = False,
|
is_pair: bool = False,
|
||||||
framework: Optional[TensorType] = None,
|
framework: Optional[TensorType] = None,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
|
|
||||||
# TODO: should we set seq_length = 1 when self.use_past = True?
|
# TODO: should we set seq_length = 1 when self.use_past = True?
|
||||||
common_inputs = super().generate_dummy_inputs(
|
common_inputs = super().generate_dummy_inputs(
|
||||||
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
@@ -577,7 +577,6 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
|
|||||||
is_pair: bool = False,
|
is_pair: bool = False,
|
||||||
framework: Optional[TensorType] = None,
|
framework: Optional[TensorType] = None,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
|
|
||||||
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ if is_torch_available():
|
|||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
|
AutoModelForVision2Seq,
|
||||||
)
|
)
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from transformers.models.auto import (
|
from transformers.models.auto import (
|
||||||
@@ -98,6 +99,7 @@ class FeaturesManager:
|
|||||||
"image-segmentation": AutoModelForImageSegmentation,
|
"image-segmentation": AutoModelForImageSegmentation,
|
||||||
"masked-im": AutoModelForMaskedImageModeling,
|
"masked-im": AutoModelForMaskedImageModeling,
|
||||||
"semantic-segmentation": AutoModelForSemanticSegmentation,
|
"semantic-segmentation": AutoModelForSemanticSegmentation,
|
||||||
|
"vision2seq-lm": AutoModelForVision2Seq,
|
||||||
}
|
}
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_TASKS_TO_TF_AUTOMODELS = {
|
_TASKS_TO_TF_AUTOMODELS = {
|
||||||
@@ -481,6 +483,9 @@ class FeaturesManager:
|
|||||||
"seq2seq-lm-with-past",
|
"seq2seq-lm-with-past",
|
||||||
onnx_config_cls="models.t5.T5OnnxConfig",
|
onnx_config_cls="models.t5.T5OnnxConfig",
|
||||||
),
|
),
|
||||||
|
"vision-encoder-decoder": supported_features_mapping(
|
||||||
|
"vision2seq-lm", onnx_config_cls="models.vision_encoder_decoder.VisionEncoderDecoderOnnxConfig"
|
||||||
|
),
|
||||||
"vit": supported_features_mapping(
|
"vit": supported_features_mapping(
|
||||||
"default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
|
"default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
|
||||||
),
|
),
|
||||||
@@ -582,6 +587,7 @@ class FeaturesManager:
|
|||||||
raise KeyError(
|
raise KeyError(
|
||||||
f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
|
f"Unknown task: {feature}. Possible values are {list(FeaturesManager._TASKS_TO_AUTOMODELS.values())}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return task_to_automodel[task]
|
return task_to_automodel[task]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -161,7 +161,6 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
"""
|
"""
|
||||||
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
|
||||||
with self.subTest(name):
|
with self.subTest(name):
|
||||||
|
|
||||||
# without past
|
# without past
|
||||||
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
onnx_config_default = OnnxConfigWithPast.from_model_config(config())
|
||||||
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
|
||||||
@@ -220,6 +219,10 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
("swin", "microsoft/swin-tiny-patch4-window7-224"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
|
||||||
|
("vision-encoder-decoder", "nlpconnect/vit-gpt2-image-captioning"),
|
||||||
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
("bloom", "bigscience/bloom-560m"),
|
("bloom", "bigscience/bloom-560m"),
|
||||||
("gpt2", "gpt2"),
|
("gpt2", "gpt2"),
|
||||||
@@ -347,6 +350,70 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
except (RuntimeError, ValueError) as e:
|
except (RuntimeError, ValueError) as e:
|
||||||
self.fail(f"{name}, {feature} -> {e}")
|
self.fail(f"{name}, {feature} -> {e}")
|
||||||
|
|
||||||
|
def _onnx_export_encoder_decoder_models(
|
||||||
|
self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"
|
||||||
|
):
|
||||||
|
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||||
|
from transformers.onnx import export
|
||||||
|
|
||||||
|
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
model = model_class.from_config(config)
|
||||||
|
|
||||||
|
onnx_config = onnx_config_class_constructor(model.config)
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers.utils import torch_version
|
||||||
|
|
||||||
|
if torch_version < onnx_config.torch_onnx_minimum_version:
|
||||||
|
pytest.skip(
|
||||||
|
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||||
|
f" {onnx_config.torch_onnx_minimum_version}, got: {torch_version}"
|
||||||
|
)
|
||||||
|
|
||||||
|
encoder_model = model.get_encoder()
|
||||||
|
decoder_model = model.get_decoder()
|
||||||
|
|
||||||
|
encoder_onnx_config = onnx_config.get_encoder_config(encoder_model.config)
|
||||||
|
decoder_onnx_config = onnx_config.get_decoder_config(encoder_model.config, decoder_model.config, feature)
|
||||||
|
|
||||||
|
preprocessor = AutoFeatureExtractor.from_pretrained(model_name)
|
||||||
|
|
||||||
|
onnx_opset = max(encoder_onnx_config.default_onnx_opset, decoder_onnx_config.default_onnx_opset)
|
||||||
|
|
||||||
|
with NamedTemporaryFile("w") as encoder_output:
|
||||||
|
onnx_inputs, onnx_outputs = export(
|
||||||
|
preprocessor, encoder_model, encoder_onnx_config, onnx_opset, Path(encoder_output.name), device=device
|
||||||
|
)
|
||||||
|
validate_model_outputs(
|
||||||
|
encoder_onnx_config,
|
||||||
|
preprocessor,
|
||||||
|
encoder_model,
|
||||||
|
Path(encoder_output.name),
|
||||||
|
onnx_outputs,
|
||||||
|
encoder_onnx_config.atol_for_validation,
|
||||||
|
)
|
||||||
|
|
||||||
|
preprocessor = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
with NamedTemporaryFile("w") as decoder_output:
|
||||||
|
onnx_inputs, onnx_outputs = export(
|
||||||
|
preprocessor,
|
||||||
|
decoder_model,
|
||||||
|
decoder_onnx_config,
|
||||||
|
onnx_config.default_onnx_opset,
|
||||||
|
Path(decoder_output.name),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
validate_model_outputs(
|
||||||
|
decoder_onnx_config,
|
||||||
|
preprocessor,
|
||||||
|
decoder_model,
|
||||||
|
Path(decoder_output.name),
|
||||||
|
onnx_outputs,
|
||||||
|
decoder_onnx_config.atol_for_validation,
|
||||||
|
)
|
||||||
|
|
||||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@@ -363,6 +430,28 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
|
||||||
|
|
||||||
|
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
@require_rjieba
|
||||||
|
def test_pytorch_export_encoder_decoder_models(
|
||||||
|
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||||
|
):
|
||||||
|
self._onnx_export_encoder_decoder_models(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
|
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_ENCODER_DECODER_MODELS))
|
||||||
|
@slow
|
||||||
|
@require_torch
|
||||||
|
@require_vision
|
||||||
|
@require_rjieba
|
||||||
|
def test_pytorch_export_encoder_decoder_models_on_cuda(
|
||||||
|
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||||
|
):
|
||||||
|
self._onnx_export_encoder_decoder_models(
|
||||||
|
test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
Reference in New Issue
Block a user