Add onnx export cuda support (#17183)
Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
Integration tests ensuring supported models are correctly exported
|
||||
"""
|
||||
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"):
|
||||
from transformers.onnx import export
|
||||
|
||||
model_class = FeaturesManager.get_model_class_for_feature(feature)
|
||||
@@ -273,7 +273,7 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
with NamedTemporaryFile("w") as output:
|
||||
try:
|
||||
onnx_inputs, onnx_outputs = export(
|
||||
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name)
|
||||
preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device
|
||||
)
|
||||
validate_model_outputs(
|
||||
onnx_config,
|
||||
@@ -294,6 +294,14 @@ class OnnxExportTestCaseV2(TestCase):
|
||||
def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor)
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
@require_vision
|
||||
@require_rjieba
|
||||
def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor):
|
||||
self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda")
|
||||
|
||||
@parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS))
|
||||
@slow
|
||||
@require_torch
|
||||
|
||||
Reference in New Issue
Block a user