|
18 | 18 | ) |
19 | 19 |
|
20 | 20 | from mypy_extensions import mypyc_attr |
| 21 | +from packaging.specifiers import InvalidSpecifier, Specifier, SpecifierSet |
| 22 | +from packaging.version import InvalidVersion, Version |
21 | 23 | from pathspec import PathSpec |
22 | 24 | from pathspec.patterns.gitwildmatch import GitWildMatchPatternError |
23 | 25 |
|
|
32 | 34 | import tomli as tomllib |
33 | 35 |
|
34 | 36 | from black.handle_ipynb_magics import jupyter_dependencies_are_installed |
| 37 | +from black.mode import TargetVersion |
35 | 38 | from black.output import err |
36 | 39 | from black.report import Report |
37 | 40 |
|
@@ -108,14 +111,103 @@ def find_pyproject_toml(path_search_start: Tuple[str, ...]) -> Optional[str]: |
108 | 111 |
|
109 | 112 | @mypyc_attr(patchable=True) |
110 | 113 | def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: |
111 | | - """Parse a pyproject toml file, pulling out relevant parts for Black |
| 114 | + """Parse a pyproject toml file, pulling out relevant parts for Black. |
112 | 115 |
|
113 | | - If parsing fails, will raise a tomllib.TOMLDecodeError |
| 116 | + If parsing fails, will raise a tomllib.TOMLDecodeError. |
114 | 117 | """ |
115 | 118 | with open(path_config, "rb") as f: |
116 | 119 | pyproject_toml = tomllib.load(f) |
117 | | - config = pyproject_toml.get("tool", {}).get("black", {}) |
118 | | - return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} |
| 120 | + config: Dict[str, Any] = pyproject_toml.get("tool", {}).get("black", {}) |
| 121 | + config = {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} |
| 122 | + |
| 123 | + if "target_version" not in config: |
| 124 | + inferred_target_version = infer_target_version(pyproject_toml) |
| 125 | + if inferred_target_version is not None: |
| 126 | + config["target_version"] = [v.name.lower() for v in inferred_target_version] |
| 127 | + |
| 128 | + return config |
| 129 | + |
| 130 | + |
| 131 | +def infer_target_version( |
| 132 | + pyproject_toml: Dict[str, Any] |
| 133 | +) -> Optional[List[TargetVersion]]: |
| 134 | + """Infer Black's target version from the project metadata in pyproject.toml. |
| 135 | +
|
| 136 | + Supports the PyPA standard format (PEP 621): |
| 137 | + https://packaging.python.org/en/latest/specifications/declaring-project-metadata/#requires-python |
| 138 | +
|
| 139 | + If the target version cannot be inferred, returns None. |
| 140 | + """ |
| 141 | + project_metadata = pyproject_toml.get("project", {}) |
| 142 | + requires_python = project_metadata.get("requires-python", None) |
| 143 | + if requires_python is not None: |
| 144 | + try: |
| 145 | + return parse_req_python_version(requires_python) |
| 146 | + except InvalidVersion: |
| 147 | + pass |
| 148 | + try: |
| 149 | + return parse_req_python_specifier(requires_python) |
| 150 | + except (InvalidSpecifier, InvalidVersion): |
| 151 | + pass |
| 152 | + |
| 153 | + return None |
| 154 | + |
| 155 | + |
| 156 | +def parse_req_python_version(requires_python: str) -> Optional[List[TargetVersion]]: |
| 157 | + """Parse a version string (i.e. ``"3.7"``) to a list of TargetVersion. |
| 158 | +
|
| 159 | + If parsing fails, will raise a packaging.version.InvalidVersion error. |
| 160 | + If the parsed version cannot be mapped to a valid TargetVersion, returns None. |
| 161 | + """ |
| 162 | + version = Version(requires_python) |
| 163 | + if version.release[0] != 3: |
| 164 | + return None |
| 165 | + try: |
| 166 | + return [TargetVersion(version.release[1])] |
| 167 | + except (IndexError, ValueError): |
| 168 | + return None |
| 169 | + |
| 170 | + |
| 171 | +def parse_req_python_specifier(requires_python: str) -> Optional[List[TargetVersion]]: |
| 172 | + """Parse a specifier string (i.e. ``">=3.7,<3.10"``) to a list of TargetVersion. |
| 173 | +
|
| 174 | + If parsing fails, will raise a packaging.specifiers.InvalidSpecifier error. |
| 175 | + If the parsed specifier cannot be mapped to a valid TargetVersion, returns None. |
| 176 | + """ |
| 177 | + specifier_set = strip_specifier_set(SpecifierSet(requires_python)) |
| 178 | + if not specifier_set: |
| 179 | + return None |
| 180 | + |
| 181 | + target_version_map = {f"3.{v.value}": v for v in TargetVersion} |
| 182 | + compatible_versions: List[str] = list(specifier_set.filter(target_version_map)) |
| 183 | + if compatible_versions: |
| 184 | + return [target_version_map[v] for v in compatible_versions] |
| 185 | + return None |
| 186 | + |
| 187 | + |
| 188 | +def strip_specifier_set(specifier_set: SpecifierSet) -> SpecifierSet: |
| 189 | + """Strip minor versions for some specifiers in the specifier set. |
| 190 | +
|
| 191 | + For background on version specifiers, see PEP 440: |
| 192 | + https://peps.python.org/pep-0440/#version-specifiers |
| 193 | + """ |
| 194 | + specifiers = [] |
| 195 | + for s in specifier_set: |
| 196 | + if "*" in str(s): |
| 197 | + specifiers.append(s) |
| 198 | + elif s.operator in ["~=", "==", ">=", "==="]: |
| 199 | + version = Version(s.version) |
| 200 | + stripped = Specifier(f"{s.operator}{version.major}.{version.minor}") |
| 201 | + specifiers.append(stripped) |
| 202 | + elif s.operator == ">": |
| 203 | + version = Version(s.version) |
| 204 | + if len(version.release) > 2: |
| 205 | + s = Specifier(f">={version.major}.{version.minor}") |
| 206 | + specifiers.append(s) |
| 207 | + else: |
| 208 | + specifiers.append(s) |
| 209 | + |
| 210 | + return SpecifierSet(",".join(str(s) for s in specifiers)) |
119 | 211 |
|
120 | 212 |
|
121 | 213 | @lru_cache() |
|
0 commit comments