From 81a73fa638adf8a3768b37f3080ddbd6cc07418a Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Mon, 15 May 2023 11:38:44 +0200 Subject: [PATCH] Fix issue introduced in PR #23163 (#23363) * fix * fix --------- Co-authored-by: ydshieh --- src/transformers/onnx/config.py | 2 +- tests/onnx/test_onnx_v2.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py index 66236e9864..02bf2421f4 100644 --- a/src/transformers/onnx/config.py +++ b/src/transformers/onnx/config.py @@ -234,7 +234,7 @@ class OnnxConfig(ABC): if is_torch_available(): from transformers.utils import get_torch_version - return get_torch_version() >= self.torch_onnx_minimum_version + return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version else: return False diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 796fa1b3ea..e160cd77f9 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -6,6 +6,7 @@ from unittest import TestCase from unittest.mock import patch import pytest +from packaging import version from parameterized import parameterized from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available @@ -321,7 +322,7 @@ class OnnxExportTestCaseV2(TestCase): if is_torch_available(): from transformers.utils import get_torch_version - if get_torch_version() < onnx_config.torch_onnx_minimum_version: + if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version: pytest.skip( "Skipping due to incompatible PyTorch version. Minimum required is" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}" @@ -364,7 +365,7 @@ class OnnxExportTestCaseV2(TestCase): if is_torch_available(): from transformers.utils import get_torch_version - if get_torch_version() < onnx_config.torch_onnx_minimum_version: + if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version: pytest.skip( "Skipping due to incompatible PyTorch version. Minimum required is" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"