Add MBART to models exportable with ONNX (#13049)

* Add MBART to models exportable with ONNX

* unittest mock

* Add tests

* Misc fixes
This commit is contained in:
Lysandre Debut
2021-08-09 14:56:04 +02:00
committed by GitHub
parent 13a9c9a354
commit 6f5ab9daf1
5 changed files with 47 additions and 6 deletions

View File

@@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = { _import_structure = {
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"], "configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
} }
if is_sentencepiece_available(): if is_sentencepiece_available():
@@ -66,7 +66,7 @@ if is_flax_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig
if is_sentencepiece_available(): if is_sentencepiece_available():
from .tokenization_mbart import MBartTokenizer from .tokenization_mbart import MBartTokenizer

View File

@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" MBART model configuration """ """ MBART model configuration """
from collections import OrderedDict
from typing import Mapping
from transformers.onnx import OnnxConfigWithPast
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
@@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig):
@property @property
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
class MBartOnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for MBART.

    Declares the named model inputs and outputs together with, for each one,
    a mapping from axis index to a symbolic axis name.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Ordered mapping of input name -> {axis index: axis name}."""
        axes = OrderedDict()
        axes["input_ids"] = {0: "batch", 1: "sequence"}
        axes["attention_mask"] = {0: "batch", 1: "sequence"}
        return axes

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Ordered mapping of output name -> {axis index: axis name}.

        When ``self.use_past`` is truthy, a ``past_keys`` entry is placed
        between the decoder and encoder hidden-state entries.
        """
        entries = [("last_hidden_state", {0: "batch", 1: "sequence"})]
        if self.use_past:
            # NOTE(review): axis 2 (not 1) is named "sequence" here; presumably
            # this reflects the cached key/value layout -- confirm against the model.
            entries.append(("past_keys", {0: "batch", 2: "sequence"}))
        entries.append(("encoder_last_hidden_state", {0: "batch", 1: "sequence"}))
        return OrderedDict(entries)

View File

@@ -9,6 +9,7 @@ from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.longformer import LongformerOnnxConfig from ..models.longformer import LongformerOnnxConfig
from ..models.mbart import MBartOnnxConfig
from ..models.roberta import RobertaOnnxConfig from ..models.roberta import RobertaOnnxConfig
from ..models.t5 import T5OnnxConfig from ..models.t5 import T5OnnxConfig
from ..models.xlm_roberta import XLMRobertaOnnxConfig from ..models.xlm_roberta import XLMRobertaOnnxConfig
@@ -58,6 +59,7 @@ class FeaturesManager:
_SUPPORTED_MODEL_KIND = { _SUPPORTED_MODEL_KIND = {
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig), "albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),

View File

@@ -25,6 +25,7 @@ from distutils.util import strtobool
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Iterator, Union from typing import Iterator, Union
from unittest import mock
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
@@ -1007,7 +1008,7 @@ def mockenv(**kwargs):
use_tf = os.getenv("USE_TF", False) use_tf = os.getenv("USE_TF", False)
""" """
return unittest.mock.patch.dict(os.environ, kwargs) return mock.patch.dict(os.environ, kwargs)
# from https://stackoverflow.com/a/34333710/9201239 # from https://stackoverflow.com/a/34333710/9201239

View File

@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
DistilBertConfig, DistilBertConfig,
GPT2Config, GPT2Config,
GPTNeoConfig, GPTNeoConfig,
MBartConfig,
RobertaConfig, RobertaConfig,
XLMRobertaConfig, XLMRobertaConfig,
is_torch_available, is_torch_available,
@@ -22,6 +23,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig # from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig from transformers.models.roberta import RobertaOnnxConfig
# from transformers.models.t5 import T5OnnxConfig # from transformers.models.t5 import T5OnnxConfig
@@ -154,7 +156,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
) )
self.assertTrue( self.assertTrue(
OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past" OnnxConfigWithPast.with_past(config()).use_past,
"OnnxConfigWithPast.from_model_config() should use_past",
) )
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set()) @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
@@ -190,6 +193,7 @@ if is_torch_available():
DistilBertModel, DistilBertModel,
GPT2Model, GPT2Model,
GPTNeoModel, GPTNeoModel,
MBartModel,
RobertaModel, RobertaModel,
XLMRobertaModel, XLMRobertaModel,
) )
@@ -204,6 +208,7 @@ if is_torch_available():
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig), # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig), ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig), ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig), # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
} }
@@ -226,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS: for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
with self.subTest(name): with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "default")) self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
tokenizer = AutoTokenizer.from_pretrained(model) tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class.from_pretrained(model)) model = model_class(config_class.from_pretrained(model))
onnx_config = onnx_config_class.default(model.config) onnx_config = onnx_config_class.from_model_config(model.config)
with NamedTemporaryFile("w") as output: with NamedTemporaryFile("w") as output:
onnx_inputs, onnx_outputs = export( onnx_inputs, onnx_outputs = export(