Make T5 compatible with ONNX (#5518)
* Default decoder inputs to encoder ones for T5 if neither are specified. * Fixing typo, now all tests are passing. * Changing einsum to operations supported by onnx * Adding a test to ensure T5 can be exported to onnx op>9 * Modified test for onnx export to make it faster * Styling changes. * Styling changes. * Changing notation for matrix multiplication Co-authored-by: Abel Riboulot <tkai@protomail.com>
This commit is contained in:
@@ -351,6 +351,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model = T5Model.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_export_to_onnx(self):
|
||||
import tempfile
|
||||
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
model = T5Model(config_and_inputs[0])
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
torch.onnx.export(
|
||||
model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class T5ModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user