From 640421c0ec38a4d5733f58acbc3e4a8284f5eb9a Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Thu, 29 Jul 2021 18:02:29 +0200 Subject: [PATCH] ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933) * Raise an issue if the pytorch version is < 1.8.0 * Attempt to add a test to ensure it correctly raises. * Missing docstring. * Second attempt, patch with string absolute import. * Let's do the call before checking it was called ... * use the correct function ... :facepalm: * Raise ImportError and AssertionError respectively when unable to find torch and torch version is not sufficient. * Correct path mock patching * relax constraint for torch_onnx_dict_inputs to ge instead of eq. * Style. * Split each version requirements for torch. * Let's compare version directly. * Import torch_version after checking pytorch is installed. * @require_torch --- src/transformers/file_utils.py | 11 +++++++++-- src/transformers/onnx/convert.py | 8 +++++++- tests/test_onnx_v2.py | 17 ++++++++++++++++- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index bc2b29354e..2d40c5edd4 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = { "bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models", } -# This is the version of torch required to run torch.fx features. +# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. TORCH_FX_REQUIRED_VERSION = version.parse("1.8") +TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8") _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False @@ -297,7 +298,7 @@ def is_torch_cuda_available(): return False -_torch_fx_available = False +_torch_fx_available = _torch_onnx_dict_inputs_support_available = False if _torch_available: torch_version = version.parse(importlib_metadata.version("torch")) _torch_fx_available = (torch_version.major, torch_version.minor) == ( @@ -305,11 +306,17 @@ if _torch_available: TORCH_FX_REQUIRED_VERSION.minor, ) + _torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION + def is_torch_fx_available(): return _torch_fx_available +def is_torch_onnx_dict_inputs_support_available(): + return _torch_onnx_dict_inputs_support_available + + def is_tf_available(): return _tf_available diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py index 651e52b9a2..cfb600dfa4 100644 --- a/src/transformers/onnx/convert.py +++ b/src/transformers/onnx/convert.py @@ -21,6 +21,7 @@ import numpy as np from packaging.version import Version, parse from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available +from ..file_utils import is_torch_onnx_dict_inputs_support_available from ..utils import logging from .config import OnnxConfig from .utils import flatten_output_collection_property @@ -79,11 +80,16 @@ def export( """ if not is_torch_available(): - raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.") + raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.") import torch from torch.onnx import export + from ..file_utils import torch_version + + if not is_torch_onnx_dict_inputs_support_available(): + raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}") + logger.info(f"Using framework PyTorch: {torch.__version__}") torch.set_grad_enabled(False) model.config.return_dict = True diff --git a/tests/test_onnx_v2.py b/tests/test_onnx_v2.py index cea2954c68..9493e8e066 100644 --- a/tests/test_onnx_v2.py +++ b/tests/test_onnx_v2.py @@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig # from transformers.models.t5 import T5OnnxConfig from transformers.models.xlm_roberta import XLMRobertaOnnxConfig -from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs +from transformers.onnx import ( + EXTERNAL_DATA_FORMAT_SIZE_LIMIT, + OnnxConfig, + ParameterFormat, + export, + validate_model_outputs, +) from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast from transformers.onnx.utils import ( compute_effective_axis_dimension, @@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase): Cover all the utilities involved to export ONNX models """ + @require_torch + @patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False) + def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available): + """ + Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0) + """ + self.assertRaises(AssertionError, export, None, None, None, None, None) + mock_is_torch_onnx_dict_inputs_support_available.assert_called() + def test_compute_effective_axis_dimension(self): """ When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.