[versions] handle version requirement ranges (#11110)
* handle version requirement ranges * add mixed requirement test * cleanup
This commit is contained in:
@@ -40,6 +40,17 @@ ops = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
|
||||||
|
if got_ver is None:
|
||||||
|
raise ValueError("got_ver is None")
|
||||||
|
if want_ver is None:
|
||||||
|
raise ValueError("want_ver is None")
|
||||||
|
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||||
|
raise ImportError(
|
||||||
|
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
|
Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
|
||||||
@@ -51,33 +62,36 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
|||||||
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
|
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# note: while pkg_resources.require_version(requirement) is a much simpler way to do it, it
|
|
||||||
# fails if some of the dependencies of the dependencies are not matching, which is not necessarily
|
|
||||||
# bad, hence the more complicated check - which also should be faster, since it doesn't check
|
|
||||||
# dependencies of dependencies.
|
|
||||||
|
|
||||||
hint = f"\n{hint}" if hint is not None else ""
|
hint = f"\n{hint}" if hint is not None else ""
|
||||||
|
|
||||||
# non-versioned check
|
# non-versioned check
|
||||||
if re.match(r"^[\w_\-\d]+$", requirement):
|
if re.match(r"^[\w_\-\d]+$", requirement):
|
||||||
pkg, op, want_ver = requirement, None, None
|
pkg, op, want_ver = requirement, None, None
|
||||||
else:
|
else:
|
||||||
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2})(.+)", requirement)
|
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
|
||||||
if not match:
|
if not match:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
|
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
|
||||||
)
|
)
|
||||||
pkg, op, want_ver = match[0]
|
pkg, want_full = match[0]
|
||||||
|
want_range = want_full.split(",") # there could be multiple requirements
|
||||||
|
wanted = {}
|
||||||
|
for w in want_range:
|
||||||
|
match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(
|
||||||
|
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
|
||||||
|
)
|
||||||
|
op, want_ver = match[0]
|
||||||
|
wanted[op] = want_ver
|
||||||
if op not in ops:
|
if op not in ops:
|
||||||
raise ValueError(f"need one of {list(ops.keys())}, but got {op}")
|
raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
|
||||||
|
|
||||||
# special case
|
# special case
|
||||||
if pkg == "python":
|
if pkg == "python":
|
||||||
got_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
got_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||||
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
for op, want_ver in wanted.items():
|
||||||
raise ImportError(
|
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
|
||||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# check if any version is installed
|
# check if any version is installed
|
||||||
@@ -88,11 +102,10 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
|||||||
f"The '{requirement}' distribution was not found and is required by this application. {hint}"
|
f"The '{requirement}' distribution was not found and is required by this application. {hint}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# check that the right version is installed if version number was provided
|
# check that the right version is installed if version number or a range was provided
|
||||||
if want_ver is not None and not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
if want_ver is not None:
|
||||||
raise ImportError(
|
for op, want_ver in wanted.items():
|
||||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
|
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def require_version_core(requirement):
|
def require_version_core(requirement):
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import numpy
|
|
||||||
|
|
||||||
from transformers.testing_utils import TestCasePlus
|
from transformers.testing_utils import TestCasePlus
|
||||||
from transformers.utils.versions import (
|
from transformers.utils.versions import (
|
||||||
importlib_metadata,
|
importlib_metadata,
|
||||||
@@ -25,7 +23,7 @@ from transformers.utils.versions import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
numpy_ver = numpy.__version__
|
numpy_ver = importlib_metadata.version("numpy")
|
||||||
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||||
|
|
||||||
|
|
||||||
@@ -54,6 +52,9 @@ class DependencyVersionCheckTest(TestCasePlus):
|
|||||||
# gt
|
# gt
|
||||||
require_version_core("numpy>1.0.0")
|
require_version_core("numpy>1.0.0")
|
||||||
|
|
||||||
|
# mix
|
||||||
|
require_version_core("numpy>1.0.0,<1000")
|
||||||
|
|
||||||
# requirement w/o version
|
# requirement w/o version
|
||||||
require_version_core("numpy")
|
require_version_core("numpy")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user