From 8cb5ecd912e09301be126c6ce6e9a22ca7153da4 Mon Sep 17 00:00:00 2001 From: Thomas Chaigneau Date: Tue, 9 Aug 2022 09:46:53 +0200 Subject: [PATCH] Add mt5 onnx config (#18394) * update features * MT5OnnxConfig added with updated with tests and docs * fix imports * fix onnc_config_cls for mt5 Co-authored-by: Thomas Chaigneau --- docs/source/en/serialization.mdx | 1 + src/transformers/models/mt5/__init__.py | 4 +-- .../models/mt5/configuration_mt5.py | 28 +++++++++++++++++++ src/transformers/onnx/features.py | 7 +++++ tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index e41ccae949..9561bbd8ec 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -79,6 +79,7 @@ Ready-made configurations include the following architectures: - mBART - MobileBERT - MobileViT +- MT5 - OpenAI GPT-2 - Perceiver - PLBart diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py index 3f04a25691..f6e717bd87 100644 --- a/src/transformers/models/mt5/__init__.py +++ b/src/transformers/models/mt5/__init__.py @@ -43,7 +43,7 @@ else: MT5TokenizerFast = T5TokenizerFast -_import_structure = {"configuration_mt5": ["MT5Config"]} +_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]} try: if not is_torch_available(): @@ -71,7 +71,7 @@ else: if TYPE_CHECKING: - from .configuration_mt5 import MT5Config + from .configuration_mt5 import MT5Config, MT5OnnxConfig try: if not is_torch_available(): diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py index ad0345f531..3e72831ad2 100644 --- a/src/transformers/models/mt5/configuration_mt5.py +++ b/src/transformers/models/mt5/configuration_mt5.py @@ -13,8 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ mT5 model configuration""" +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxSeq2SeqConfigWithPast from ...utils import logging @@ -143,3 +145,29 @@ class MT5Config(PretrainedConfig): @property def num_hidden_layers(self): return self.num_layers + + +# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig +class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = { + "input_ids": {0: "batch", 1: "encoder_sequence"}, + "attention_mask": {0: "batch", 1: "encoder_sequence"}, + } + if self.use_past: + common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence" + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + + return common_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 3eea94c8c1..8d8b8190e4 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -383,6 +383,13 @@ class FeaturesManager: "image-classification", onnx_config_cls="models.mobilevit.MobileViTOnnxConfig", ), + "mt5": supported_features_mapping( + "default", + "default-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + onnx_config_cls="models.mt5.MT5OnnxConfig", + ), "m2m-100": supported_features_mapping( "default", "default-with-past", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index c22406841a..98ab0fad13 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -224,6 +224,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { ("mbart", "sshleifer/tiny-mbart"), ("t5", "t5-small"), ("marian", "Helsinki-NLP/opus-mt-en-de"), + ("mt5", "google/mt5-base"), ("m2m-100", "facebook/m2m100_418M"), ("blenderbot-small", "facebook/blenderbot_small-90M"), ("blenderbot", "facebook/blenderbot-400M-distill"),