Add MBART to models exportable with ONNX (#13049)

* Add MBART to models exportable with ONNX

* unittest mock

* Add tests

* Misc fixes
This commit is contained in:
Lysandre Debut
2021-08-09 14:56:04 +02:00
committed by GitHub
parent 13a9c9a354
commit 6f5ab9daf1
5 changed files with 47 additions and 6 deletions

View File

@@ -10,6 +10,7 @@ from transformers import ( # LongformerConfig,; T5Config,
DistilBertConfig,
GPT2Config,
GPTNeoConfig,
MBartConfig,
RobertaConfig,
XLMRobertaConfig,
is_torch_available,
@@ -22,6 +23,7 @@ from transformers.models.distilbert import DistilBertOnnxConfig
# from transformers.models.longformer import LongformerOnnxConfig
from transformers.models.gpt2 import GPT2OnnxConfig
from transformers.models.gpt_neo import GPTNeoOnnxConfig
from transformers.models.mbart import MBartOnnxConfig
from transformers.models.roberta import RobertaOnnxConfig
# from transformers.models.t5 import T5OnnxConfig
@@ -154,7 +156,8 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
)
self.assertTrue(
OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.default() should use_past"
OnnxConfigWithPast.with_past(config()).use_past,
"OnnxConfigWithPast.from_model_config() should use_past",
)
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
@@ -190,6 +193,7 @@ if is_torch_available():
DistilBertModel,
GPT2Model,
GPTNeoModel,
MBartModel,
RobertaModel,
XLMRobertaModel,
)
@@ -204,6 +208,7 @@ if is_torch_available():
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
("MBart", "sshleifer/tiny-mbart", MBartModel, MBartConfig, MBartOnnxConfig),
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
}
@@ -226,11 +231,11 @@ class OnnxExportTestCaseV2(TestCase):
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
with self.subTest(name):
self.assertTrue(hasattr(onnx_config_class, "default"))
self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
tokenizer = AutoTokenizer.from_pretrained(model)
model = model_class(config_class.from_pretrained(model))
onnx_config = onnx_config_class.default(model.config)
onnx_config = onnx_config_class.from_model_config(model.config)
with NamedTemporaryFile("w") as output:
onnx_inputs, onnx_outputs = export(