Make T5 compatible with ONNX (#5518)

* Default decoder inputs to encoder ones for T5 if neither are specified.

* Fixing typo, now all tests are passing.

* Changing einsum to operations supported by onnx

* Adding a test to ensure T5 can be exported to onnx op>9

* Modified test for onnx export to make it faster

* Styling changes.

* Styling changes.

* Changing notation for matrix multiplication

Co-authored-by: Abel Riboulot <tkai@protomail.com>
This commit is contained in:
Abel
2020-07-07 11:32:29 +02:00
committed by GitHub
parent 989ae326b5
commit 6912265711
2 changed files with 26 additions and 5 deletions

View File

@@ -351,6 +351,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model = T5Model.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_export_to_onnx(self):
import tempfile
config_and_inputs = self.model_tester.prepare_config_and_inputs()
model = T5Model(config_and_inputs[0])
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
)
@require_torch
class T5ModelIntegrationTests(unittest.TestCase):