[versions] handle version requirement ranges (#11110)
* handle version requirement ranges * add mixed requirement test * cleanup
This commit is contained in:
@@ -40,6 +40,17 @@ ops = {
|
||||
}
|
||||
|
||||
|
||||
def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint):
|
||||
if got_ver is None:
|
||||
raise ValueError("got_ver is None")
|
||||
if want_ver is None:
|
||||
raise ValueError("want_ver is None")
|
||||
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
raise ImportError(
|
||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
|
||||
)
|
||||
|
||||
|
||||
def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
||||
"""
|
||||
Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
|
||||
@@ -51,33 +62,36 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
||||
hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met
|
||||
"""
|
||||
|
||||
# note: while pkg_resources.require_version(requirement) is a much simpler way to do it, it
|
||||
# fails if some of the dependencies of the dependencies are not matching, which is not necessarily
|
||||
# bad, hence the more complicated check - which also should be faster, since it doesn't check
|
||||
# dependencies of dependencies.
|
||||
|
||||
hint = f"\n{hint}" if hint is not None else ""
|
||||
|
||||
# non-versioned check
|
||||
if re.match(r"^[\w_\-\d]+$", requirement):
|
||||
pkg, op, want_ver = requirement, None, None
|
||||
else:
|
||||
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2})(.+)", requirement)
|
||||
match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
|
||||
)
|
||||
pkg, op, want_ver = match[0]
|
||||
if op not in ops:
|
||||
raise ValueError(f"need one of {list(ops.keys())}, but got {op}")
|
||||
pkg, want_full = match[0]
|
||||
want_range = want_full.split(",") # there could be multiple requirements
|
||||
wanted = {}
|
||||
for w in want_range:
|
||||
match = re.findall(r"^([\s!=<>]{1,2})(.+)", w)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}"
|
||||
)
|
||||
op, want_ver = match[0]
|
||||
wanted[op] = want_ver
|
||||
if op not in ops:
|
||||
raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}")
|
||||
|
||||
# special case
|
||||
if pkg == "python":
|
||||
got_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||
if not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
raise ImportError(
|
||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}."
|
||||
)
|
||||
for op, want_ver in wanted.items():
|
||||
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
|
||||
return
|
||||
|
||||
# check if any version is installed
|
||||
@@ -88,11 +102,10 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
|
||||
f"The '{requirement}' distribution was not found and is required by this application. {hint}"
|
||||
)
|
||||
|
||||
# check that the right version is installed if version number was provided
|
||||
if want_ver is not None and not ops[op](version.parse(got_ver), version.parse(want_ver)):
|
||||
raise ImportError(
|
||||
f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}"
|
||||
)
|
||||
# check that the right version is installed if version number or a range was provided
|
||||
if want_ver is not None:
|
||||
for op, want_ver in wanted.items():
|
||||
_compare_versions(op, got_ver, want_ver, requirement, pkg, hint)
|
||||
|
||||
|
||||
def require_version_core(requirement):
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
|
||||
import sys
|
||||
|
||||
import numpy
|
||||
|
||||
from transformers.testing_utils import TestCasePlus
|
||||
from transformers.utils.versions import (
|
||||
importlib_metadata,
|
||||
@@ -25,7 +23,7 @@ from transformers.utils.versions import (
|
||||
)
|
||||
|
||||
|
||||
numpy_ver = numpy.__version__
|
||||
numpy_ver = importlib_metadata.version("numpy")
|
||||
python_ver = ".".join([str(x) for x in sys.version_info[:3]])
|
||||
|
||||
|
||||
@@ -54,6 +52,9 @@ class DependencyVersionCheckTest(TestCasePlus):
|
||||
# gt
|
||||
require_version_core("numpy>1.0.0")
|
||||
|
||||
# mix
|
||||
require_version_core("numpy>1.0.0,<1000")
|
||||
|
||||
# requirement w/o version
|
||||
require_version_core("numpy")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user