support ONNX export of XDropout in deberta{,_v2} and sew_d (#17502)
* support ONNX export of XDropout in deberta{,_v2}
* black
* copy to sew_d
* add test
* isort
* use pytest.mark.filterwarnings
* review comments
This commit is contained in:
@@ -185,6 +185,21 @@ class XDropout(torch.autograd.Function):
|
|||||||
else:
|
else:
|
||||||
return grad_output, None
|
return grad_output, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
|
||||||
|
dropout_p = local_ctx
|
||||||
|
if isinstance(local_ctx, DropoutContext):
|
||||||
|
dropout_p = local_ctx.dropout
|
||||||
|
# StableDropout only calls this function when training.
|
||||||
|
train = True
|
||||||
|
# TODO: We should check if the opset_version being used to export
|
||||||
|
# is >= 12 here, but there's no good way to do that. As-is, if the
|
||||||
|
# opset_version < 12, export will fail with a CheckerError.
|
||||||
|
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
|
||||||
|
# if opset_version < 12:
|
||||||
|
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
|
||||||
|
return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
|
||||||
|
|
||||||
|
|
||||||
class StableDropout(nn.Module):
|
class StableDropout(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -191,6 +191,21 @@ class XDropout(torch.autograd.Function):
|
|||||||
else:
|
else:
|
||||||
return grad_output, None
|
return grad_output, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
|
||||||
|
dropout_p = local_ctx
|
||||||
|
if isinstance(local_ctx, DropoutContext):
|
||||||
|
dropout_p = local_ctx.dropout
|
||||||
|
# StableDropout only calls this function when training.
|
||||||
|
train = True
|
||||||
|
# TODO: We should check if the opset_version being used to export
|
||||||
|
# is >= 12 here, but there's no good way to do that. As-is, if the
|
||||||
|
# opset_version < 12, export will fail with a CheckerError.
|
||||||
|
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
|
||||||
|
# if opset_version < 12:
|
||||||
|
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
|
||||||
|
return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
|
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
|
||||||
class StableDropout(nn.Module):
|
class StableDropout(nn.Module):
|
||||||
|
|||||||
@@ -595,6 +595,21 @@ class XDropout(torch.autograd.Function):
|
|||||||
else:
|
else:
|
||||||
return grad_output, None
|
return grad_output, None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def symbolic(g: torch._C.Graph, input: torch._C.Value, local_ctx: Union[float, DropoutContext]) -> torch._C.Value:
|
||||||
|
dropout_p = local_ctx
|
||||||
|
if isinstance(local_ctx, DropoutContext):
|
||||||
|
dropout_p = local_ctx.dropout
|
||||||
|
# StableDropout only calls this function when training.
|
||||||
|
train = True
|
||||||
|
# TODO: We should check if the opset_version being used to export
|
||||||
|
# is >= 12 here, but there's no good way to do that. As-is, if the
|
||||||
|
# opset_version < 12, export will fail with a CheckerError.
|
||||||
|
# Once https://github.com/pytorch/pytorch/issues/78391 is fixed, do something like:
|
||||||
|
# if opset_version < 12:
|
||||||
|
# return torch.onnx.symbolic_opset9.dropout(g, input, dropout_p, train)
|
||||||
|
return torch.onnx.symbolic_opset12.dropout(g, input, dropout_p, train)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
|
# Copied from transformers.models.deberta.modeling_deberta.StableDropout
|
||||||
class StableDropout(nn.Module):
|
class StableDropout(nn.Module):
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
@@ -26,6 +27,11 @@ from transformers.testing_utils import require_onnx, require_rjieba, require_tf,
|
|||||||
if is_torch_available() or is_tf_available():
|
if is_torch_available() or is_tf_available():
|
||||||
from transformers.onnx.features import FeaturesManager
|
from transformers.onnx.features import FeaturesManager
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from transformers.models.deberta import modeling_deberta
|
||||||
|
|
||||||
|
|
||||||
@require_onnx
|
@require_onnx
|
||||||
class OnnxUtilsTestCaseV2(TestCase):
|
class OnnxUtilsTestCaseV2(TestCase):
|
||||||
@@ -356,3 +362,40 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
self, test_name, name, model_name, feature, onnx_config_class_constructor
|
||||||
):
|
):
|
||||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDropoutTestCase(TestCase):
|
||||||
|
"""Tests export of StableDropout module."""
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@pytest.mark.filterwarnings("ignore:.*Dropout.*:UserWarning:torch.onnx.*") # torch.onnx is spammy.
|
||||||
|
def test_training(self):
|
||||||
|
"""Tests export of StableDropout in training mode."""
|
||||||
|
devnull = open(os.devnull, "wb")
|
||||||
|
# drop_prob must be > 0 for the test to be meaningful
|
||||||
|
sd = modeling_deberta.StableDropout(0.1)
|
||||||
|
# Avoid warnings in training mode
|
||||||
|
do_constant_folding = False
|
||||||
|
# Dropout is a no-op in inference mode
|
||||||
|
training = torch.onnx.TrainingMode.PRESERVE
|
||||||
|
input = (torch.randn(2, 2),)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
sd,
|
||||||
|
input,
|
||||||
|
devnull,
|
||||||
|
opset_version=12, # Minimum supported
|
||||||
|
do_constant_folding=do_constant_folding,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expected to fail with opset_version < 12
|
||||||
|
with self.assertRaises(Exception):
|
||||||
|
torch.onnx.export(
|
||||||
|
sd,
|
||||||
|
input,
|
||||||
|
devnull,
|
||||||
|
opset_version=11,
|
||||||
|
do_constant_folding=do_constant_folding,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user