diff --git a/CHANGES.md b/CHANGES.md index a8b556cb7e8..5223b537f17 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -16,6 +16,7 @@ - Add trailing commas to collection literals even if there's a comment after the last entry (#3393) +- Fix target version inference (#3583) ### Configuration diff --git a/src/black/files.py b/src/black/files.py index 8c0131126b7..78c1d26b38e 100644 --- a/src/black/files.py +++ b/src/black/files.py @@ -1,7 +1,10 @@ import io +import operator import os +import re import sys -from functools import lru_cache +from enum import Enum +from functools import lru_cache, reduce from pathlib import Path from typing import ( TYPE_CHECKING, @@ -168,46 +171,174 @@ def parse_req_python_version(requires_python: str) -> Optional[List[TargetVersio return None +class Endpoint(Enum): + CLOSED = 1 + OPEN = 2 + + +CLOSED = Endpoint.CLOSED +OPEN = Endpoint.OPEN + + +class Interval: + def __init__(self, left: Endpoint, lower: Any, upper: Any, right: Endpoint): + if not (lower < upper or (lower == upper and left == right == CLOSED)): + raise ValueError("empty interval") + self.left = left + self.lower = lower + self.upper = upper + self.right = right + + +class IntervalSet: + """Represents a union of intervals.""" + + def __init__(self, intervals: List[Any]): + self.intervals = intervals + + def __and__(self, other: "IntervalSet") -> "IntervalSet": + new_intervals = [] + for i1 in self.intervals: + for i2 in other.intervals: + if i1.lower < i2.lower: + lower = i2.lower + left = i2.left + elif i2.lower < i1.lower: + lower = i1.lower + left = i1.left + else: + lower = i1.lower + left = CLOSED if i1.left == i2.left == CLOSED else OPEN + if i1.upper < i2.upper: + upper = i1.upper + right = i1.right + elif i2.upper < i1.upper: + upper = i2.upper + right = i2.right + else: + upper = i1.upper + right = CLOSED if i1.right == i2.right == CLOSED else OPEN + try: + new_intervals.append(Interval(left, lower, upper, right)) + except ValueError: + pass + return IntervalSet(new_intervals) + + def __or__(self, other: "IntervalSet") -> "IntervalSet": + return IntervalSet(self.intervals + other.intervals) + + @property + def empty(self) -> bool: + return len(self.intervals) == 0 + + +def interval(left: Endpoint, lower: Any, upper: Any, right: Endpoint) -> IntervalSet: + try: + return IntervalSet([Interval(left, lower, upper, right)]) + except ValueError: + return empty + + +def singleton(value: Any) -> IntervalSet: + return interval(CLOSED, value, value, CLOSED) + + +empty = IntervalSet([]) +min_ver = Version(f"3.{tuple(TargetVersion)[0].value}") +above_max_ver = Version(f"3.{tuple(TargetVersion)[-1].value + 1}") + + +def get_interval_set(specifier: Specifier) -> IntervalSet: + if specifier.version.endswith(".*"): + assert specifier.operator in ("==", "!=") + wildcard = True + ver = Version(specifier.version[:-2]) + else: + wildcard = False + if specifier.operator != "===": + ver = Version(specifier.version) + + if specifier.operator == ">=": + return interval(CLOSED, ver, above_max_ver, OPEN) + if specifier.operator == ">": + return interval(OPEN, ver, above_max_ver, OPEN) + if specifier.operator == "<=": + return interval(CLOSED, min_ver, ver, CLOSED) + if specifier.operator == "<": + return interval(CLOSED, min_ver, ver, OPEN) + if specifier.operator == "==": + if wildcard: + return interval( + CLOSED, + ver, + Version(".".join(map(str, (*ver.release[:-1], ver.release[-1] + 1)))), + OPEN, + ) + else: + return singleton(ver) + if specifier.operator == "!=": + if wildcard: + return interval(CLOSED, min_ver, ver, OPEN) | interval( + CLOSED, + Version(".".join(map(str, (*ver.release[:-1], ver.release[-1] + 1)))), + above_max_ver, + OPEN, + ) + else: + return interval(CLOSED, min_ver, ver, OPEN) | interval( + OPEN, ver, above_max_ver, OPEN + ) + if specifier.operator == "~=": + return interval( + CLOSED, + ver, + Version(".".join(map(str, (*ver.release[:-2], ver.release[-2] + 1)))), + OPEN, + ) + if specifier.operator == "===": + # This operator should do a simple string equality test. Pip compares + # it with "X.Y.Z", so only if the version in the specifier is in this + # exact format, it has a chance to match. + if re.fullmatch(r"\d+\.\d+\.\d+", specifier.version): + return singleton(Version(specifier.version)) + else: + return empty + raise AssertionError() # pragma: no cover + + def parse_req_python_specifier(requires_python: str) -> Optional[List[TargetVersion]]: """Parse a specifier string (i.e. ``">=3.7,<3.10"``) to a list of TargetVersion. If parsing fails, will raise a packaging.specifiers.InvalidSpecifier error. - If the parsed specifier cannot be mapped to a valid TargetVersion, returns None. + If the parsed specifier is empty or cannot be mapped to a valid TargetVersion, + returns None. """ - specifier_set = strip_specifier_set(SpecifierSet(requires_python)) + specifier_set = SpecifierSet(requires_python) if not specifier_set: + # This means that the specifier has no version clauses. Technically, + # all Python versions are included in this specifier. But because the + # user didn't refer to any specific Python version, we fall back to + # per-file auto-detection. return None - target_version_map = {f"3.{v.value}": v for v in TargetVersion} - compatible_versions: List[str] = list(specifier_set.filter(target_version_map)) - if compatible_versions: - return [target_version_map[v] for v in compatible_versions] - return None - - -def strip_specifier_set(specifier_set: SpecifierSet) -> SpecifierSet: - """Strip minor versions for some specifiers in the specifier set. - - For background on version specifiers, see PEP 440: - https://peps.python.org/pep-0440/#version-specifiers - """ - specifiers = [] - for s in specifier_set: - if "*" in str(s): - specifiers.append(s) - elif s.operator in ["~=", "==", ">=", "==="]: - version = Version(s.version) - stripped = Specifier(f"{s.operator}{version.major}.{version.minor}") - specifiers.append(stripped) - elif s.operator == ">": - version = Version(s.version) - if len(version.release) > 2: - s = Specifier(f">={version.major}.{version.minor}") - specifiers.append(s) - else: - specifiers.append(s) - - return SpecifierSet(",".join(str(s) for s in specifiers)) + # First, we determine the version interval set from the specifier set (the + # clauses in the specifier set are connected by the logical and operator). + # Then, for each supported Python (minor) version, we check whether the + # interval set intersects with the interval for this Python version. + spec_intervals = reduce( + operator.and_, + map(get_interval_set, specifier_set), + interval(CLOSED, min_ver, above_max_ver, OPEN), + ) + target_versions = [ + tv + for tv in TargetVersion + if not (spec_intervals & get_interval_set(Specifier(f"==3.{tv.value}.*"))).empty + ] + if not target_versions: + return None + else: + return target_versions @lru_cache() diff --git a/tests/test_black.py b/tests/test_black.py index e5e17777715..61e378e2448 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -1572,23 +1572,32 @@ def test_parse_pyproject_toml_project_metadata(self) -> None: self.assertEqual(config.get("target_version"), expected) def test_infer_target_version(self) -> None: + ALL_TARGET_VERSIONS = [ + TargetVersion.PY33, + TargetVersion.PY34, + TargetVersion.PY35, + TargetVersion.PY36, + TargetVersion.PY37, + TargetVersion.PY38, + TargetVersion.PY39, + TargetVersion.PY310, + TargetVersion.PY311, + ] for version, expected in [ ("3.6", [TargetVersion.PY36]), ("3.11.0rc1", [TargetVersion.PY311]), (">=3.10", [TargetVersion.PY310, TargetVersion.PY311]), (">=3.10.6", [TargetVersion.PY310, TargetVersion.PY311]), ("<3.6", [TargetVersion.PY33, TargetVersion.PY34, TargetVersion.PY35]), - (">3.7,<3.10", [TargetVersion.PY38, TargetVersion.PY39]), - (">3.7,!=3.8,!=3.9", [TargetVersion.PY310, TargetVersion.PY311]), + (">=3.8,<3.10", [TargetVersion.PY38, TargetVersion.PY39]), + (">=3.8,!=3.8.*,!=3.9.*", [TargetVersion.PY310, TargetVersion.PY311]), ( - "> 3.9.4, != 3.10.3", - [TargetVersion.PY39, TargetVersion.PY310, TargetVersion.PY311], + ">3.7,<3.10", + [TargetVersion.PY37, TargetVersion.PY38, TargetVersion.PY39], ), ( - "!=3.3,!=3.4", + ">3.7,!=3.8,!=3.9", [ - TargetVersion.PY35, - TargetVersion.PY36, TargetVersion.PY37, TargetVersion.PY38, TargetVersion.PY39, @@ -1597,10 +1606,12 @@ def test_infer_target_version(self) -> None: ], ), ( - "==3.*", + "> 3.9.4, != 3.10.3", + [TargetVersion.PY39, TargetVersion.PY310, TargetVersion.PY311], + ), + ( + "!=3.3.*,!=3.4.*", [ - TargetVersion.PY33, - TargetVersion.PY34, TargetVersion.PY35, TargetVersion.PY36, TargetVersion.PY37, @@ -1610,6 +1621,8 @@ def test_infer_target_version(self) -> None: TargetVersion.PY311, ], ), + ("!=3.3,!=3.4", ALL_TARGET_VERSIONS), + ("==3.*", ALL_TARGET_VERSIONS), ("==3.8.*", [TargetVersion.PY38]), (None, None), ("", None), @@ -1620,11 +1633,151 @@ def test_infer_target_version(self) -> None: ("3.2", None), ("2.7.18", None), ("==2.7", None), - (">3.10,<3.11", None), + (">=3.11,<3.11", None), + (">3.10,<3.11", [TargetVersion.PY310]), + # Many of the following test cases were written to test edge cases + # while developing the algorithm. To ensure that they still test + # something useful, be careful when changing the tests when support + # for a Python version is added or removed. E.g. when removing + # Python 3.3 support, make sure to change "3.2" to "3.3" and "3.3" + # to "3.4" in the following test cases. + (">=2.20", ALL_TARGET_VERSIONS), + (">=3", ALL_TARGET_VERSIONS), + (">=3.0", ALL_TARGET_VERSIONS), + (">=3.2", ALL_TARGET_VERSIONS), + (">=3.3", ALL_TARGET_VERSIONS), + (">=3.3.20", ALL_TARGET_VERSIONS), + ( + ">=3.4", + [ + TargetVersion.PY34, + TargetVersion.PY35, + TargetVersion.PY36, + TargetVersion.PY37, + TargetVersion.PY38, + TargetVersion.PY39, + TargetVersion.PY310, + TargetVersion.PY311, + ], + ), + (">=3.11", [TargetVersion.PY311]), + (">=3.11.20", [TargetVersion.PY311]), + (">=3.12", None), + (">=4", None), + (">=3.11,==3.11", [TargetVersion.PY311]), + (">2.20", ALL_TARGET_VERSIONS), + (">3", ALL_TARGET_VERSIONS), + (">3.11", [TargetVersion.PY311]), + (">3.11.20", [TargetVersion.PY311]), + (">3.12", None), + (">4", None), + (">3.11,==3.11", None), + ("<=2", None), + ("<=3", None), + ("<=3.2", None), + ("<=3.3", [TargetVersion.PY33]), + ("<=3.3.1", [TargetVersion.PY33]), + ("<=4", ALL_TARGET_VERSIONS), + ("<=3.11,==3.11", [TargetVersion.PY311]), + ("<2", None), + ("<3", None), + ("<3.3", None), + ("<3.3.1", [TargetVersion.PY33]), + ("<3.3.0.1", [TargetVersion.PY33]), + ("<4", ALL_TARGET_VERSIONS), + ("<3.11,==3.11", None), + ("==2.*", None), + ("==3.*", ALL_TARGET_VERSIONS), + ("==3.0.*", None), + ("==3.2.*", None), + ("==3.3.*", [TargetVersion.PY33]), + ("==3.3.20.*", [TargetVersion.PY33]), + ("==4.*", None), + ("==2", None), + ("==3", None), + ("==3.0", None), + ("==3.2", None), + ("==3.3", [TargetVersion.PY33]), + ("==3.3.20", [TargetVersion.PY33]), + ("==3.4", [TargetVersion.PY34]), + ("==4", None), + ("!=2.*", ALL_TARGET_VERSIONS), + ("!=3.*", None), + ("!=3.0.*", ALL_TARGET_VERSIONS), + ("!=3.2.*", ALL_TARGET_VERSIONS), + ( + "!=3.3.*", + [ + TargetVersion.PY34, + TargetVersion.PY35, + TargetVersion.PY36, + TargetVersion.PY37, + TargetVersion.PY38, + TargetVersion.PY39, + TargetVersion.PY310, + TargetVersion.PY311, + ], + ), + ("!=3.3.20.*", ALL_TARGET_VERSIONS), + ("!=4.*", ALL_TARGET_VERSIONS), + ("!=2", ALL_TARGET_VERSIONS), + ("!=3", ALL_TARGET_VERSIONS), + ("!=3.0", ALL_TARGET_VERSIONS), + ("!=3.2", ALL_TARGET_VERSIONS), + ("!=3.3", ALL_TARGET_VERSIONS), + ("!=3.3.20", ALL_TARGET_VERSIONS), + ("!=4", ALL_TARGET_VERSIONS), + ("~=2", None), + ("~=3", None), + ("~=2.0", None), + ("~=3.0", ALL_TARGET_VERSIONS), + ("~=3.0.0", None), + ("~=3.2", ALL_TARGET_VERSIONS), + ("~=3.3", ALL_TARGET_VERSIONS), + ("~=3.3.0", [TargetVersion.PY33]), + ("~=3.3.0.0", [TargetVersion.PY33]), + ("~=3.3.0.1", [TargetVersion.PY33]), + ( + "~=3.4", + [ + TargetVersion.PY34, + TargetVersion.PY35, + TargetVersion.PY36, + TargetVersion.PY37, + TargetVersion.PY38, + TargetVersion.PY39, + TargetVersion.PY310, + TargetVersion.PY311, + ], + ), + ("~=3.11", [TargetVersion.PY311]), + ("~=3.11.20", [TargetVersion.PY311]), + ("~=3.12", None), + ("~=4.0", None), + ("~=4", None), + ("===2", None), + ("===3", None), + ("===3.0", None), + ("===3.0.0", None), + ("===3", None), + ("===3.3", None), + ("===3.3.0", [TargetVersion.PY33]), + ("===3.3.1", [TargetVersion.PY33]), + ("===3.3.1.0", None), + ("===3.3.1.1", None), + ("===4", None), + (">3.3,<3.3", None), + (">3.3.0,==3.3.1", [TargetVersion.PY33]), + ("==3.3.0,==3.3.1", None), + (">3.3,<=3.3", None), + ("==3.3,!=3.3", None), + ("!=3.4.*,==3.5", [TargetVersion.PY35]), + ("==3.3,!=3.4", [TargetVersion.PY33]), + ("<3.4,==3.4", None), ]: test_toml = {"project": {"requires-python": version}} result = black.files.infer_target_version(test_toml) - self.assertEqual(result, expected) + self.assertEqual(result, expected, version) def test_read_pyproject_toml(self) -> None: test_toml_file = THIS_DIR / "test.toml"