* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -234,7 +234,7 @@ class OnnxConfig(ABC):
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.utils import get_torch_version
|
from transformers.utils import get_torch_version
|
||||||
|
|
||||||
return get_torch_version() >= self.torch_onnx_minimum_version
|
return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from unittest import TestCase
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from packaging import version
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
|
from transformers import AutoConfig, PreTrainedTokenizerBase, is_tf_available, is_torch_available
|
||||||
@@ -321,7 +322,7 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.utils import get_torch_version
|
from transformers.utils import get_torch_version
|
||||||
|
|
||||||
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
|
if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||||
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
||||||
@@ -364,7 +365,7 @@ class OnnxExportTestCaseV2(TestCase):
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from transformers.utils import get_torch_version
|
from transformers.utils import get_torch_version
|
||||||
|
|
||||||
if get_torch_version() < onnx_config.torch_onnx_minimum_version:
|
if version.parse(get_torch_version()) < onnx_config.torch_onnx_minimum_version:
|
||||||
pytest.skip(
|
pytest.skip(
|
||||||
"Skipping due to incompatible PyTorch version. Minimum required is"
|
"Skipping due to incompatible PyTorch version. Minimum required is"
|
||||||
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
f" {onnx_config.torch_onnx_minimum_version}, got: {get_torch_version()}"
|
||||||
|
|||||||
Reference in New Issue
Block a user