From 6f5ab9daf1915d13b911231d9e05576d317c4e16 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Mon, 9 Aug 2021 14:56:04 +0200 Subject: [PATCH] Add MBART to models exportable with ONNX (#13049) * Add MBART to models exportable with ONNX * unittest mock * Add tests * Misc fixes --- src/transformers/models/mbart/__init__.py | 4 +-- .../models/mbart/configuration_mbart.py | 33 +++++++++++++++++++ src/transformers/onnx/features.py | 2 ++ src/transformers/testing_utils.py | 3 +- tests/test_onnx_v2.py | 11 +++++-- 5 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index f17ab91b27..613c90afbe 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -28,7 +28,7 @@ from ...file_utils import ( _import_structure = { - "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"], + "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"], } if is_sentencepiece_available(): @@ -66,7 +66,7 @@ if is_flax_available(): if TYPE_CHECKING: - from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig + from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig if is_sentencepiece_available(): from .tokenization_mbart import MBartTokenizer diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index d8f8364850..610ebf46cb 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. """ MBART model configuration """ +from collections import OrderedDict +from typing import Mapping + +from transformers.onnx import OnnxConfigWithPast from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig): @property def hidden_size(self) -> int: return self.d_model + + +class MBartOnnxConfig(OnnxConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("input_ids", {0: "batch", 1: "sequence"}), + ("attention_mask", {0: "batch", 1: "sequence"}), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.use_past: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("past_keys", {0: "batch", 2: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) + else: + return OrderedDict( + [ + ("last_hidden_state", {0: "batch", 1: "sequence"}), + ("encoder_last_hidden_state", {0: "batch", 1: "sequence"}), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 530f68b20a..73f7df359c 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -9,6 +9,7 @@ from ..models.distilbert import DistilBertOnnxConfig from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.longformer import LongformerOnnxConfig +from ..models.mbart import MBartOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.t5 import T5OnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig @@ -58,6 +59,7 @@ class FeaturesManager: _SUPPORTED_MODEL_KIND = { "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig), "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), + "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig), "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 8fa904a4e6..4ac25179ca 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -25,6 +25,7 @@ from distutils.util import strtobool from io import StringIO from pathlib import Path from typing import Iterator, Union +from unittest import mock from transformers import logging as transformers_logging @@ -1007,7 +1008,7 @@ def mockenv(**kwargs): use_tf = os.getenv("USE_TF", False) """ - return unittest.mock.patch.dict(os.environ, kwargs) + return mock.patch.dict(os.environ, kwargs) # from https://stackoverflow.com/a/34333710/9201239 diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index 3e99162d80..2c07b0abaa 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config, DistilBertConfig, GPT2Config, GPTNeoConfig, + MBartConfig, RobertaConfig, XLMRobertaConfig, is_torch_available, @@ -22,6 +23,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig # from transformers.models.longformer import LongformerOnnxConfig from transformers.models.gpt2 import GPT2OnnxConfig from transformers.models.gpt_neo import GPTNeoOnnxConfig +from transformers.models.mbart import MBartOnnxConfig from transformers.models.roberta import RobertaOnnxConfig # from transformers.models.t5 import T5OnnxConfig @@ -154,7 +156,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase): ) self.assertTrue( - OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past" + OnnxConfigWithPast.with_past(config()).use_past, + "OnnxConfigWithPast.from_model_config() should use_past", ) @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) @@ -190,6 +193,7 @@ if is_torch_available(): DistilBertModel, GPT2Model, GPTNeoModel, + MBartModel, RobertaModel, XLMRobertaModel, ) @@ -204,6 +208,7 @@ if is_torch_available(): # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), + ("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig), # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), } @@ -226,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase): for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS: with self.subTest(name): - self.assertTrue(hasattr(onnx_config_class, "default")) + self.assertTrue(hasattr(onnx_config_class, "from_model_config")) tokenizer = AutoTokenizer.from_pretrained(model) model = model_class(config_class.from_pretrained(model)) - onnx_config = onnx_config_class.default(model.config) + onnx_config = onnx_config_class.from_model_config(model.config) with NamedTemporaryFile("w") as output: onnx_inputs, onnx_outputs = export(