Add onnx export cuda support (#17183)

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Jingya HUANG
2022-05-18 17:52:13 +02:00
committed by GitHub
parent adc0ff2502
commit 6da76b9c2a
3 changed files with 28 additions and 4 deletions

View File

@@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase):
Integration tests ensuring supported models are correctly exported
"""
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
from transformers.onnx import export
model_class = FeaturesManager.get_model_class_for_feature(feature)
@@ -273,7 +273,7 @@ class OnnxExportTestCaseV2(TestCase):
with NamedTemporaryFile("w") as output:
try:
onnx_inputs, onnx_outputs = export(
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
)
validate_model_outputs(
onnx_config,
@@ -294,6 +294,14 @@ class OnnxExportTestCaseV2(TestCase):
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
@slow
@require_torch
@require_vision
@require_rjieba
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
@slow
@require_torch