From 1c151283129cf4b9ae78296f8459c85c8bdf4cb7 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 7 Apr 2021 09:09:38 -0700 Subject: [PATCH] [versions] handle version requirement ranges (#11110) * handle version requirement ranges * add mixed requirement test * cleanup --- src/transformers/utils/versions.py | 49 +++++++++++++++++++----------- tests/test_versions_utils.py | 7 +++-- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/transformers/utils/versions.py b/src/transformers/utils/versions.py index 028dbcc6c8..b573a361b9 100644 --- a/src/transformers/utils/versions.py +++ b/src/transformers/utils/versions.py @@ -40,6 +40,17 @@ ops = { } +def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): + if got_ver is None: + raise ValueError("got_ver is None") + if want_ver is None: + raise ValueError("want_ver is None") + if not ops[op](version.parse(got_ver), version.parse(want_ver)): + raise ImportError( + f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" + ) + + def require_version(requirement: str, hint: Optional[str] = None) -> None: """ Perform a runtime check of the dependency versions, using the exact same syntax used by pip. @@ -51,33 +62,36 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None: hint (:obj:`str`, `optional`): what suggestion to print in case of requirements not being met """ - # note: while pkg_resources.require_version(requirement) is a much simpler way to do it, it - # fails if some of the dependencies of the dependencies are not matching, which is not necessarily - # bad, hence the more complicated check - which also should be faster, since it doesn't check - # dependencies of dependencies. - hint = f"\n{hint}" if hint is not None else "" # non-versioned check if re.match(r"^[\w_\-\d]+$", requirement): pkg, op, want_ver = requirement, None, None else: - match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2})(.+)", requirement) + match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) if not match: raise ValueError( f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" ) - pkg, op, want_ver = match[0] - if op not in ops: - raise ValueError(f"need one of {list(ops.keys())}, but got {op}") + pkg, want_full = match[0] + want_range = want_full.split(",") # there could be multiple requirements + wanted = {} + for w in want_range: + match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) + if not match: + raise ValueError( + f"requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but got {requirement}" + ) + op, want_ver = match[0] + wanted[op] = want_ver + if op not in ops: + raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") # special case if pkg == "python": got_ver = ".".join([str(x) for x in sys.version_info[:3]]) - if not ops[op](version.parse(got_ver), version.parse(want_ver)): - raise ImportError( - f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}." - ) + for op, want_ver in wanted.items(): + _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) return # check if any version is installed @@ -88,11 +102,10 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None: f"The '{requirement}' distribution was not found and is required by this application. {hint}" ) - # check that the right version is installed if version number was provided - if want_ver is not None and not ops[op](version.parse(got_ver), version.parse(want_ver)): - raise ImportError( - f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" - ) + # check that the right version is installed if version number or a range was provided + if want_ver is not None: + for op, want_ver in wanted.items(): + _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) def require_version_core(requirement): diff --git a/tests/test_versions_utils.py b/tests/test_versions_utils.py index 04c6d78ec3..1d488b980b 100644 --- a/tests/test_versions_utils.py +++ b/tests/test_versions_utils.py @@ -14,8 +14,6 @@ import sys -import numpy - from transformers.testing_utils import TestCasePlus from transformers.utils.versions import ( importlib_metadata, @@ -25,7 +23,7 @@ from transformers.utils.versions import ( ) -numpy_ver = numpy.__version__ +numpy_ver = importlib_metadata.version("numpy") python_ver = ".".join([str(x) for x in sys.version_info[:3]]) @@ -54,6 +52,9 @@ class DependencyVersionCheckTest(TestCasePlus): # gt require_version_core("numpy>1.0.0") + # mix + require_version_core("numpy>1.0.0,<1000") + # requirement w/o version require_version_core("numpy")