From 9d7b70bcd77a3a684c160b636faf48b4da4c1b72 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Wed, 3 Aug 2022 03:33:44 -0700 Subject: [PATCH] support ONNX export of XDropout in deberta{,_v2} and sew_d (#17502) * support ONNX export of XDropout in deberta{,_v2} * black * copy to sew_d * add test * isort * use pytest.mark.filterwarnings * review comments --- .../models/deberta/modeling_deberta.py | 15 +++++++ .../models/deberta_v2/modeling_deberta_v2.py | 15 +++++++ .../models/sew_d/modeling_sew_d.py | 15 +++++++ tests/onnx/test_onnx_v2.py | 43 +++++++++++++++++++ 4 files changed, 88 insertions(+) diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 45121b23bf..2d9e647c13 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -185,6 +185,21 @@ class XDropout(torch.autograd.Function): else: return grad_output, None + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. + # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train) + class StableDropout(nn.Module): """ diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index 7d4a6f5c38..738981648a 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -191,6 +191,21 @@ class XDropout(torch.autograd.Function): else: return grad_output, None + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. + # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train) + # Copied from transformers.models.deberta.modeling_deberta.StableDropout class StableDropout(nn.Module): diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 8dc210d06c..e582705ab0 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -595,6 +595,21 @@ class XDropout(torch.autograd.Function): else: return grad_output, None + @staticmethod + def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value: + dropout_p = local_ctx + if isinstance(local_ctx, DropoutContext): + dropout_p = local_ctx.dropout + # StableDropout only calls this function when training. + train = True + # TODO: We should check if the opset_version being used to export + # is > 12 here, but there's no good way to do that. As-is, if the + # opset_version < 12, export will fail with a CheckerError. + # Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like: + # if opset_version < 12: + # return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train) + return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train) + # Copied from transformers.models.deberta.modeling_deberta.StableDropout class StableDropout(nn.Module): diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 6b22dc3420..c15910734f 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from tempfile import NamedTemporaryFile from unittest import TestCase @@ -26,6 +27,11 @@ from transformers.testing_utils import require_onnx, require_rjieba, require_tf, if is_torch_available() or is_tf_available(): from transformers.onnx.features import FeaturesManager +if is_torch_available(): + import torch + + from transformers.models.deberta import modeling_deberta + @require_onnx class OnnxUtilsTestCaseV2(TestCase): @@ -356,3 +362,40 @@ class OnnxExportTestCaseV2(TestCase): self, test_name, name, model_name, feature, onnx_config_class_constructor ): self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) + + +class StableDropoutTestCase(TestCase): + """Tests export of StableDropout module.""" + + @require_torch + @pytest.mark.filterwarnings("ignore:.*Dropout.*:UserWarning:torch.onnx.*") # torch.onnx is spammy. + def test_training(self): + """Tests export of StableDropout in training mode.""" + devnull = open(os.devnull, "wb") + # drop_prob must be > 0 for the test to be meaningful + sd = modeling_deberta.StableDropout(0.1) + # Avoid warnings in training mode + do_constant_folding = False + # Dropout is a no-op in inference mode + training = torch.onnx.TrainingMode.PRESERVE + input = (torch.randn(2, 2),) + + torch.onnx.export( + sd, + input, + devnull, + opset_version=12, # Minimum supported + do_constant_folding=do_constant_folding, + training=training, + ) + + # Expected to fail with opset_version < 12 + with self.assertRaises(Exception): + torch.onnx.export( + sd, + input, + devnull, + opset_version=11, + do_constant_folding=do_constant_folding, + training=training, + )