ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933)
* Raise an issue if the pytorch version is < 1.8.0
* Attempt to add a test to ensure it correctly raises.
* Missing docstring.
* Second attempt, patch with string absolute import.
* Let's do the call before checking it was called ...
* use the correct function ... 🤦
* Raise ImportError and AssertionError respectively when unable to find torch and torch version is not sufficient.
* Correct path mock patching
* relax constraint for torch_onnx_dict_inputs to ge instead of eq.
* Style.
* Split each version requirements for torch.
* Let's compare version directly.
* Import torch_version after checking pytorch is installed.
* @require_torch
This commit is contained in:
@@ -274,8 +274,9 @@ PRESET_MIRROR_DICT = {
|
|||||||
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
|
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
|
||||||
}
|
}
|
||||||
|
|
||||||
# This is the version of torch required to run torch.fx features.
|
# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
|
||||||
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
|
TORCH_FX_REQUIRED_VERSION = version.parse("1.8")
|
||||||
|
TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
|
||||||
|
|
||||||
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
|
_is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
|
||||||
|
|
||||||
@@ -297,7 +298,7 @@ def is_torch_cuda_available():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
_torch_fx_available = False
|
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||||
if _torch_available:
|
if _torch_available:
|
||||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||||
_torch_fx_available = (torch_version.major, torch_version.minor) == (
|
_torch_fx_available = (torch_version.major, torch_version.minor) == (
|
||||||
@@ -305,11 +306,17 @@ if _torch_available:
|
|||||||
TORCH_FX_REQUIRED_VERSION.minor,
|
TORCH_FX_REQUIRED_VERSION.minor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_torch_onnx_dict_inputs_support_available = torch_version >= TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION
|
||||||
|
|
||||||
|
|
||||||
def is_torch_fx_available():
|
def is_torch_fx_available():
|
||||||
return _torch_fx_available
|
return _torch_fx_available
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_onnx_dict_inputs_support_available():
|
||||||
|
return _torch_onnx_dict_inputs_support_available
|
||||||
|
|
||||||
|
|
||||||
def is_tf_available():
|
def is_tf_available():
|
||||||
return _tf_available
|
return _tf_available
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import numpy as np
|
|||||||
from packaging.version import Version, parse
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
|
||||||
|
from ..file_utils import is_torch_onnx_dict_inputs_support_available
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .config import OnnxConfig
|
from .config import OnnxConfig
|
||||||
from .utils import flatten_output_collection_property
|
from .utils import flatten_output_collection_property
|
||||||
@@ -79,11 +80,16 @@ def export(
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
|
raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
|
from ..file_utils import torch_version
|
||||||
|
|
||||||
|
if not is_torch_onnx_dict_inputs_support_available():
|
||||||
|
raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
|
||||||
|
|
||||||
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
logger.info(f"Using framework PyTorch: {torch.__version__}")
|
||||||
torch.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
model.config.return_dict = True
|
model.config.return_dict = True
|
||||||
|
|||||||
@@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig
|
|||||||
|
|
||||||
# from transformers.models.t5 import T5OnnxConfig
|
# from transformers.models.t5 import T5OnnxConfig
|
||||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||||
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
|
from transformers.onnx import (
|
||||||
|
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||||
|
OnnxConfig,
|
||||||
|
ParameterFormat,
|
||||||
|
export,
|
||||||
|
validate_model_outputs,
|
||||||
|
)
|
||||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||||
from transformers.onnx.utils import (
|
from transformers.onnx.utils import (
|
||||||
compute_effective_axis_dimension,
|
compute_effective_axis_dimension,
|
||||||
@@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
|
|||||||
Cover all the utilities involved to export ONNX models
|
Cover all the utilities involved to export ONNX models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
|
||||||
|
def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
|
||||||
|
"""
|
||||||
|
Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
|
||||||
|
"""
|
||||||
|
self.assertRaises(AssertionError, export, None, None, None, None, None)
|
||||||
|
mock_is_torch_onnx_dict_inputs_support_available.assert_called()
|
||||||
|
|
||||||
def test_compute_effective_axis_dimension(self):
|
def test_compute_effective_axis_dimension(self):
|
||||||
"""
|
"""
|
||||||
When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
|
When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
|
||||||
|
|||||||
Reference in New Issue
Block a user