From 6da76b9c2ac3880dcd573d7051e0b0b00cd6c7f6 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Wed, 18 May 2022 17:52:13 +0200 Subject: [PATCH] Add onnx export cuda support (#17183) Co-authored-by: Lysandre Debut Co-authored-by: lewtun --- .../models/big_bird/modeling_big_bird.py | 2 +- src/transformers/onnx/convert.py | 18 +++++++++++++++++- tests/onnx/test_onnx_v2.py | 12 ++++++++++-- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 3c41c457bd..070831db4d 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -3099,7 +3099,7 @@ class BigBirdForQuestionAnswering(BigBirdPreTrainedModel): # setting lengths logits to `-inf` logits_mask = self.prepare_question_mask(question_lengths, seqlen) if token_type_ids is None: - token_type_ids = torch.ones(logits_mask.size(), dtype=int) - logits_mask + token_type_ids = torch.ones(logits_mask.size(), dtype=int, device=logits_mask.device) - logits_mask logits_mask = logits_mask logits_mask[:, 0] = False logits_mask.unsqueeze_(2) diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 2f1789bbdc..43224532e6 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -86,6 +86,7 @@ def export_pytorch( opset: int, output: Path, tokenizer: "PreTrainedTokenizer" = None, + device: str = "cpu", ) -> Tuple[List[str], List[str]]: """ Export a PyTorch model to an ONNX Intermediate Representation (IR) @@ -101,6 +102,8 @@ def export_pytorch( The version of the ONNX operator set to use. output (`Path`): Directory to store the exported ONNX model. + device (`str`, *optional*, defaults to `cpu`): + The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from @@ -137,6 +140,10 @@ def export_pytorch( # Ensure inputs match # TODO: Check when exporting QA we provide "is_pair=True" model_inputs = config.generate_dummy_inputs(preprocessor, framework=TensorType.PYTORCH) + device = torch.device(device) + if device.type == "cuda" and torch.cuda.is_available(): + model.to(device) + model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items()) inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys()) onnx_outputs = list(config.outputs.keys()) @@ -268,6 +275,7 @@ def export( opset: int, output: Path, tokenizer: "PreTrainedTokenizer" = None, + device: str = "cpu", ) -> Tuple[List[str], List[str]]: """ Export a Pytorch or TensorFlow model to an ONNX Intermediate Representation (IR) @@ -283,6 +291,9 @@ def export( The version of the ONNX operator set to use. output (`Path`): Directory to store the exported ONNX model. + device (`str`, *optional*, defaults to `cpu`): + The device on which the ONNX model will be exported. Either `cpu` or `cuda`. Only PyTorch is supported for + export on CUDA devices. Returns: `Tuple[List[str], List[str]]`: A tuple with an ordered list of the model's inputs, and the named inputs from @@ -294,6 +305,9 @@ def export( "Please install torch or tensorflow first." ) + if is_tf_available() and isinstance(model, TFPreTrainedModel) and device == "cuda": + raise RuntimeError("`tf2onnx` does not support export on CUDA device.") + if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None: raise ValueError("You cannot provide both a tokenizer and a preprocessor to export the model.") if tokenizer is not None: @@ -318,7 +332,7 @@ def export( ) if is_torch_available() and issubclass(type(model), PreTrainedModel): - return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer) + return export_pytorch(preprocessor, model, config, opset, output, tokenizer=tokenizer, device=device) elif is_tf_available() and issubclass(type(model), TFPreTrainedModel): return export_tensorflow(preprocessor, model, config, opset, output, tokenizer=tokenizer) @@ -359,6 +373,8 @@ def validate_model_outputs( session = InferenceSession(onnx_model.as_posix(), options, providers=["CPUExecutionProvider"]) # Compute outputs from the reference model + if is_torch_available() and issubclass(type(reference_model), PreTrainedModel): + reference_model.to("cpu") ref_outputs = reference_model(**reference_model_inputs) ref_outputs_dict = {} diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index eb234e9896..5ebef03873 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -242,7 +242,7 @@ class OnnxExportTestCaseV2(TestCase): Integration tests ensuring supported models are correctly exported """ - def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): + def _onnx_export(self, test_name, name, model_name, feature, onnx_config_class_constructor, device="cpu"): from transformers.onnx import export model_class = FeaturesManager.get_model_class_for_feature(feature) @@ -273,7 +273,7 @@ class OnnxExportTestCaseV2(TestCase): with NamedTemporaryFile("w") as output: try: onnx_inputs, onnx_outputs = export( - preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name) + preprocessor, model, onnx_config, onnx_config.default_onnx_opset, Path(output.name), device=device ) validate_model_outputs( onnx_config, @@ -294,6 +294,14 @@ class OnnxExportTestCaseV2(TestCase): def test_pytorch_export(self, test_name, name, model_name, feature, onnx_config_class_constructor): self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS)) + @slow + @require_torch + @require_vision + @require_rjieba + def test_pytorch_export_on_cuda(self, test_name, name, model_name, feature, onnx_config_class_constructor): + self._onnx_export(test_name, name, model_name, feature, onnx_config_class_constructor, device="cuda") + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_WITH_PAST_MODELS)) @slow @require_torch