ONNX v2 raises an Exception when using PyTorch < 1.8.0 (#12933)
* Raise an issue if the pytorch version is < 1.8.0
* Attempt to add a test to ensure it correctly raises.
* Missing docstring.
* Second attempt, patch with string absolute import.
* Let's do the call before checking it was called ...
* use the correct function ... 🤦
* Raise ImportError and AssertionError respectively when unable to find torch and torch version is not sufficient.
* Correct path mock patching
* relax constraint for torch_onnx_dict_inputs to ge instead of eq.
* Style.
* Split each version requirements for torch.
* Let's compare version directly.
* Import torch_version after checking pytorch is installed.
* @require_torch
This commit is contained in:
@@ -24,7 +24,13 @@ from transformers.models.roberta import RobertaOnnxConfig
|
||||
|
||||
# from transformers.models.t5 import T5OnnxConfig
|
||||
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
|
||||
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
|
||||
from transformers.onnx import (
|
||||
EXTERNAL_DATA_FORMAT_SIZE_LIMIT,
|
||||
OnnxConfig,
|
||||
ParameterFormat,
|
||||
export,
|
||||
validate_model_outputs,
|
||||
)
|
||||
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
|
||||
from transformers.onnx.utils import (
|
||||
compute_effective_axis_dimension,
|
||||
@@ -40,6 +46,15 @@ class OnnxUtilsTestCaseV2(TestCase):
|
||||
Cover all the utilities involved to export ONNX models
|
||||
"""
|
||||
|
||||
@require_torch
|
||||
@patch("transformers.onnx.convert.is_torch_onnx_dict_inputs_support_available", return_value=False)
|
||||
def test_ensure_pytorch_version_ge_1_8_0(self, mock_is_torch_onnx_dict_inputs_support_available):
|
||||
"""
|
||||
Ensure we raise an Exception if the pytorch version is unsupported (< 1.8.0)
|
||||
"""
|
||||
self.assertRaises(AssertionError, export, None, None, None, None, None)
|
||||
mock_is_torch_onnx_dict_inputs_support_available.assert_called()
|
||||
|
||||
def test_compute_effective_axis_dimension(self):
|
||||
"""
|
||||
When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
|
||||
|
||||
Reference in New Issue
Block a user