Fix issue introduced in PR #23163 (#23363)

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-05-15 11:38:44 +02:00
committed by GitHub
parent 2958b55fe5
commit 81a73fa638
2 changed files with 4 additions and 3 deletions

View File

@@ -234,7 +234,7 @@ class OnnxConfig(ABC):
if is_torch_available(): if is_torch_available():
from transformers.utils import get_torch_version from transformers.utils import get_torch_version
return get_torch_version() >= self.torch_onnx_minimum_version return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
else: else:
return False return False

View File

@@ -6,6 +6,7 @@ from unittest import TestCase
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from packaging import version
from parameterized import parameterized from parameterized import parameterized
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
@@ -321,7 +322,7 @@ class OnnxExportTestCaseV2(TestCase):
if is_torch_available(): if is_torch_available():
from transformers.utils import get_torch_version from transformers.utils import get_torch_version
if get_torch_version() < onnx_config.torch_onnx_minimum_version: if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
pytest.skip( pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is" "Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
@@ -364,7 +365,7 @@ class OnnxExportTestCaseV2(TestCase):
if is_torch_available(): if is_torch_available():
from transformers.utils import get_torch_version from transformers.utils import get_torch_version
if get_torch_version() < onnx_config.torch_onnx_minimum_version: if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
pytest.skip( pytest.skip(
"Skipping due to incompatible PyTorch version. Minimum required is" "Skipping due to incompatible PyTorch version. Minimum required is"
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}" f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"