Add MBART to models exportable with ONNX (#13049)
* Add MBART to models exportable with ONNX
* unittest mock
* Add tests
* Misc fixes
This commit is contained in:
@@ -28,7 +28,7 @@ from ...file_utils import (
|
|||||||
|
|
||||||
|
|
||||||
_import_structure = {
|
_import_structure = {
|
||||||
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig"],
|
"configuration_mbart": ["MBART_PRETRAINED_CONFIG_ARCHIVE_MAP", "MBartConfig", "MBartOnnxConfig"],
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
@@ -66,7 +66,7 @@ if is_flax_available():
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
|
from .configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig, MBartOnnxConfig
|
||||||
|
|
||||||
if is_sentencepiece_available():
|
if is_sentencepiece_available():
|
||||||
from .tokenization_mbart import MBartTokenizer
|
from .tokenization_mbart import MBartTokenizer
|
||||||
|
|||||||
@@ -13,6 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" MBART model configuration """
|
""" MBART model configuration """
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Mapping
|
||||||
|
|
||||||
|
from transformers.onnx import OnnxConfigWithPast
|
||||||
|
|
||||||
from ...configuration_utils import PretrainedConfig
|
from ...configuration_utils import PretrainedConfig
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
@@ -171,3 +175,32 @@ class MBartConfig(PretrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def hidden_size(self) -> int:
|
def hidden_size(self) -> int:
|
||||||
return self.d_model
|
return self.d_model
|
||||||
|
|
||||||
|
|
||||||
|
class MBartOnnxConfig(OnnxConfigWithPast):
    """ONNX configuration for MBART, declaring the named inputs and outputs.

    Each property returns an ordered mapping from a tensor name to a dict of
    ``{axis index: axis label}``.
    NOTE(review): the labelled axes are presumably the dynamic (variable-size)
    axes consumed by the ONNX exporter -- confirm against ``OnnxConfigWithPast``.
    """

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the input names with their labelled axes, in declaration order."""
        axes_by_input = OrderedDict()
        axes_by_input["input_ids"] = {0: "batch", 1: "sequence"}
        axes_by_input["attention_mask"] = {0: "batch", 1: "sequence"}
        return axes_by_input

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return the output names with their labelled axes, in declaration order.

        When ``self.use_past`` is truthy, a ``past_keys`` entry is placed between
        the decoder and encoder hidden-state entries; otherwise it is omitted.
        """
        axes_by_output = [("last_hidden_state", {0: "batch", 1: "sequence"})]
        if self.use_past:
            # Unlike the hidden states, past_keys labels axis 2 as the sequence axis.
            axes_by_output.append(("past_keys", {0: "batch", 2: "sequence"}))
        axes_by_output.append(("encoder_last_hidden_state", {0: "batch", 1: "sequence"}))
        return OrderedDict(axes_by_output)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from ..models.distilbert import DistilBertOnnxConfig
|
|||||||
from ..models.gpt2 import GPT2OnnxConfig
|
from ..models.gpt2 import GPT2OnnxConfig
|
||||||
from ..models.gpt_neo import GPTNeoOnnxConfig
|
from ..models.gpt_neo import GPTNeoOnnxConfig
|
||||||
from ..models.longformer import LongformerOnnxConfig
|
from ..models.longformer import LongformerOnnxConfig
|
||||||
|
from ..models.mbart import MBartOnnxConfig
|
||||||
from ..models.roberta import RobertaOnnxConfig
|
from ..models.roberta import RobertaOnnxConfig
|
||||||
from ..models.t5 import T5OnnxConfig
|
from ..models.t5 import T5OnnxConfig
|
||||||
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
from ..models.xlm_roberta import XLMRobertaOnnxConfig
|
||||||
@@ -58,6 +59,7 @@ class FeaturesManager:
|
|||||||
_SUPPORTED_MODEL_KIND = {
|
_SUPPORTED_MODEL_KIND = {
|
||||||
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
|
"albert": supported_features_mapping("default", onnx_config_cls=AlbertOnnxConfig),
|
||||||
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
|
||||||
|
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
|
||||||
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
|
||||||
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
|
||||||
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from distutils.util import strtobool
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator, Union
|
from typing import Iterator, Union
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
@@ -1007,7 +1008,7 @@ def mockenv(**kwargs):
|
|||||||
use_tf = os.getenv("USE_TF", False)
|
use_tf = os.getenv("USE_TF", False)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return unittest.mock.patch.dict(os.environ, kwargs)
|
return mock.patch.dict(os.environ, kwargs)
|
||||||
|
|
||||||
|
|
||||||
# from https://stackoverflow.com/a/34333710/9201239
|
# from https://stackoverflow.com/a/34333710/9201239
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
|
|||||||
DistilBertConfig,
|
DistilBertConfig,
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
GPTNeoConfig,
|
GPTNeoConfig,
|
||||||
|
MBartConfig,
|
||||||
RobertaConfig,
|
RobertaConfig,
|
||||||
XLMRobertaConfig,
|
XLMRobertaConfig,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -22,6 +23,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
|
|||||||
# from transformers.models.longformer import LongformerOnnxConfig
|
# from transformers.models.longformer import LongformerOnnxConfig
|
||||||
from transformers.models.gpt2 import GPT2OnnxConfig
|
from transformers.models.gpt2 import GPT2OnnxConfig
|
||||||
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
from transformers.models.gpt_neo import GPTNeoOnnxConfig
|
||||||
|
from transformers.models.mbart import MBartOnnxConfig
|
||||||
from transformers.models.roberta import RobertaOnnxConfig
|
from transformers.models.roberta import RobertaOnnxConfig
|
||||||
|
|
||||||
# from transformers.models.t5 import T5OnnxConfig
|
# from transformers.models.t5 import T5OnnxConfig
|
||||||
@@ -154,7 +156,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past"
|
OnnxConfigWithPast.with_past(config()).use_past,
|
||||||
|
"OnnxConfigWithPast.from_model_config() should use_past",
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
|
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
|
||||||
@@ -190,6 +193,7 @@ if is_torch_available():
|
|||||||
DistilBertModel,
|
DistilBertModel,
|
||||||
GPT2Model,
|
GPT2Model,
|
||||||
GPTNeoModel,
|
GPTNeoModel,
|
||||||
|
MBartModel,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
XLMRobertaModel,
|
XLMRobertaModel,
|
||||||
)
|
)
|
||||||
@@ -204,6 +208,7 @@ if is_torch_available():
|
|||||||
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
|
||||||
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
|
||||||
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
|
||||||
|
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
|
||||||
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
|
|
||||||
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
|
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
|
||||||
with self.subTest(name):
|
with self.subTest(name):
|
||||||
self.assertTrue(hasattr(onnx_config_class, "default"))
|
self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
model = model_class(config_class.from_pretrained(model))
|
model = model_class(config_class.from_pretrained(model))
|
||||||
onnx_config = onnx_config_class.default(model.config)
|
onnx_config = onnx_config_class.from_model_config(model.config)
|
||||||
|
|
||||||
with NamedTemporaryFile("w") as output:
|
with NamedTemporaryFile("w") as output:
|
||||||
onnx_inputs, onnx_outputs = export(
|
onnx_inputs, onnx_outputs = export(
|
||||||
|
|||||||
Reference in New Issue
Block a user