@@ -54,14 +54,11 @@ def max_version(
5454
5555
5656def check_msvc (msvc_base_path : Path , version : str ) -> bool :
57- return all (
58- x .exists ()
59- for x in [
60- msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe" ,
61- msvc_base_path / version / "include" / "vcruntime.h" ,
62- msvc_base_path / version / "lib" / "x64" / "vcruntime.lib" ,
63- ]
64- )
57+ return all (x .exists () for x in [
58+ msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe" ,
59+ msvc_base_path / version / "include" / "vcruntime.h" ,
60+ msvc_base_path / version / "lib" / "x64" / "vcruntime.lib" ,
61+ ])
6562
6663
6764def find_msvc_env () -> tuple [Optional [Path ], Optional [str ]]:
@@ -72,20 +69,16 @@ def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
7269
7370 version = os .getenv ("VCToolsVersion" )
7471 if not check_msvc (msvc_base_path , version ):
75- warnings .warn (
76- f"Environment variables VCINSTALLDIR = { os .getenv ('VCINSTALLDIR' )} , "
77- f"VCToolsVersion = { os .getenv ('VCToolsVersion' )} are set, "
78- "but this MSVC installation is incomplete."
79- )
72+ warnings .warn (f"Environment variables VCINSTALLDIR = { os .getenv ('VCINSTALLDIR' )} , "
73+ f"VCToolsVersion = { os .getenv ('VCToolsVersion' )} are set, "
74+ "but this MSVC installation is incomplete." )
8075 return None , None
8176
8277 return msvc_base_path , version
8378
8479
8580def find_msvc_vswhere () -> tuple [Optional [Path ], Optional [str ]]:
86- vswhere_path = find_in_program_files (
87- r"Microsoft Visual Studio\Installer\vswhere.exe"
88- )
81+ vswhere_path = find_in_program_files (r"Microsoft Visual Studio\Installer\vswhere.exe" )
8982 if vswhere_path is None :
9083 return None , None
9184
@@ -111,9 +104,7 @@ def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
111104 if not msvc_base_path .exists ():
112105 return None , None
113106
114- version = max_version (
115- os .listdir (msvc_base_path ), check = partial (check_msvc , msvc_base_path )
116- )
107+ version = max_version (os .listdir (msvc_base_path ), check = partial (check_msvc , msvc_base_path ))
117108 if version is None :
118109 return None , None
119110
@@ -132,9 +123,7 @@ def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
132123 if not msvc_base_path .exists ():
133124 continue
134125
135- version = max_version (
136- os .listdir (msvc_base_path ), check = partial (check_msvc , msvc_base_path )
137- )
126+ version = max_version (os .listdir (msvc_base_path ), check = partial (check_msvc , msvc_base_path ))
138127 if version is None :
139128 continue
140129
@@ -153,9 +142,7 @@ def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
153142 paths = sorted (paths )[::- 1 ]
154143 for msvc_base_path in paths :
155144 msvc_base_path = Path (msvc_base_path )
156- version = max_version (
157- os .listdir (msvc_base_path ), check = partial (check_msvc , msvc_base_path )
158- )
145+ version = max_version (os .listdir (msvc_base_path ), check = partial (check_msvc , msvc_base_path ))
159146 if version is None :
160147 continue
161148 return msvc_base_path , version
@@ -188,13 +175,10 @@ def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
188175
189176
190177def check_winsdk (winsdk_base_path : Path , version : str ) -> bool :
191- return all (
192- x .exists ()
193- for x in [
194- winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h" ,
195- winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib" ,
196- ]
197- )
178+ return all (x .exists () for x in [
179+ winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h" ,
180+ winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib" ,
181+ ])
198182
199183
200184def find_winsdk_env () -> tuple [Optional [Path ], Optional [str ]]:
@@ -207,18 +191,14 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
207191 if version is None :
208192 version = os .getenv ("WindowsSDKVer" )
209193 if version is None :
210- warnings .warn (
211- f"Environment variable WindowsSdkDir = { winsdk_base_path } , "
212- "but WindowsSDKVersion (or WindowsSDKVer) is not set."
213- )
194+ warnings .warn (f"Environment variable WindowsSdkDir = { winsdk_base_path } , "
195+ "but WindowsSDKVersion (or WindowsSDKVer) is not set." )
214196 return None , None
215197 version = version .rstrip ("\\ " )
216198 if not check_winsdk (winsdk_base_path , version ):
217- warnings .warn (
218- f"Environment variables WindowsSdkDir = { winsdk_base_path } , "
219- f"WindowsSDKVersion (or WindowsSDKVer) = { version } are set, "
220- "but this Windows SDK installation is incomplete."
221- )
199+ warnings .warn (f"Environment variables WindowsSdkDir = { winsdk_base_path } , "
200+ f"WindowsSDKVersion (or WindowsSDKVer) = { version } are set, "
201+ "but this Windows SDK installation is incomplete." )
222202 return None , None
223203
224204 return winsdk_base_path , version
@@ -227,9 +207,7 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
227207def find_winsdk_registry () -> tuple [Optional [Path ], Optional [str ]]:
228208 try :
229209 reg = winreg .ConnectRegistry (None , winreg .HKEY_LOCAL_MACHINE )
230- key = winreg .OpenKeyEx (
231- reg , r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0"
232- )
210+ key = winreg .OpenKeyEx (reg , r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0" )
233211 folder = winreg .QueryValueEx (key , "InstallationFolder" )[0 ]
234212 winreg .CloseKey (key )
235213 except OSError :
@@ -296,9 +274,7 @@ def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
296274
297275
298276@functools .lru_cache
299- def find_msvc_winsdk (
300- env_only : bool = False ,
301- ) -> tuple [Optional [str ], list [str ], list [str ]]:
277+ def find_msvc_winsdk (env_only : bool = False , ) -> tuple [Optional [str ], list [str ], list [str ]]:
302278 msvc_bin_path , msvc_inc_dirs , msvc_lib_dirs = find_msvc (env_only )
303279 winsdk_inc_dirs , winsdk_lib_dirs = find_winsdk (env_only )
304280 return (
@@ -314,9 +290,9 @@ def find_python() -> list[str]:
314290 if sysconfig .get_config_var ("Py_GIL_DISABLED" ):
315291 version += "t"
316292 for python_base_path in [
317- sys .exec_prefix ,
318- sys .base_exec_prefix ,
319- os .path .dirname (sys .executable ),
293+ sys .exec_prefix ,
294+ sys .base_exec_prefix ,
295+ os .path .dirname (sys .executable ),
320296 ]:
321297 python_lib_dir = Path (python_base_path ) / "libs"
322298 if (python_lib_dir / f"python{ version } .lib" ).exists ():
@@ -328,44 +304,35 @@ def find_python() -> list[str]:
328304
329305def check_and_find_cuda (base_path : Path ) -> tuple [Optional [str ], list [str ], list [str ]]:
330306 # pip
331- if all (
332- x .exists ()
333- for x in [
307+ if all (x .exists () for x in [
334308 base_path / "cuda_nvcc" / "bin" / "ptxas.exe" ,
335309 base_path / "cuda_runtime" / "include" / "cuda.h" ,
336310 base_path / "cuda_runtime" / "lib" / "x64" / "cuda.lib" ,
337- ]
338- ):
311+ ]):
339312 return (
340313 str (base_path / "cuda_nvcc" / "bin" ),
341314 [str (base_path / "cuda_runtime" / "include" )],
342315 [str (base_path / "cuda_runtime" / "lib" / "x64" )],
343316 )
344317
345318 # conda
346- if all (
347- x .exists ()
348- for x in [
319+ if all (x .exists () for x in [
349320 base_path / "bin" / "ptxas.exe" ,
350321 base_path / "include" / "cuda.h" ,
351322 base_path / "lib" / "cuda.lib" ,
352- ]
353- ):
323+ ]):
354324 return (
355325 str (base_path / "bin" ),
356326 [str (base_path / "include" )],
357327 [str (base_path / "lib" )],
358328 )
359329
360330 # bundled or system-wide
361- if all (
362- x .exists ()
363- for x in [
331+ if all (x .exists () for x in [
364332 base_path / "bin" / "ptxas.exe" ,
365333 base_path / "include" / "cuda.h" ,
366334 base_path / "lib" / "x64" / "cuda.lib" ,
367- ]
368- ):
335+ ]):
369336 return (
370337 str (base_path / "bin" ),
371338 [str (base_path / "include" )],
@@ -382,19 +349,15 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
382349 continue
383350
384351 cuda_base_path = Path (cuda_base_path )
385- cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs = check_and_find_cuda (
386- cuda_base_path
387- )
352+ cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs = check_and_find_cuda (cuda_base_path )
388353 if cuda_bin_path :
389354 return cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs
390355
391356 return None , [], []
392357
393358
394359def find_cuda_bundled () -> tuple [Optional [str ], list [str ], list [str ]]:
395- cuda_base_path = (
396- Path (sysconfig .get_paths ()["platlib" ]) / "triton" / "backends" / "nvidia"
397- )
360+ cuda_base_path = (Path (sysconfig .get_paths ()["platlib" ]) / "triton" / "backends" / "nvidia" )
398361 return check_and_find_cuda (cuda_base_path )
399362
400363
@@ -418,9 +381,7 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
418381 paths = sorted (paths )[::- 1 ]
419382 for cuda_base_path in paths :
420383 cuda_base_path = Path (cuda_base_path )
421- cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs = check_and_find_cuda (
422- cuda_base_path
423- )
384+ cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs = check_and_find_cuda (cuda_base_path )
424385 if cuda_bin_path :
425386 return cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs
426387
@@ -430,11 +391,11 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
430391@functools .lru_cache
431392def find_cuda () -> tuple [Optional [str ], list [str ], list [str ]]:
432393 for f in [
433- find_cuda_env ,
434- find_cuda_bundled ,
435- find_cuda_pip ,
436- find_cuda_conda ,
437- find_cuda_hardcoded ,
394+ find_cuda_env ,
395+ find_cuda_bundled ,
396+ find_cuda_pip ,
397+ find_cuda_conda ,
398+ find_cuda_hardcoded ,
438399 ]:
439400 cuda_bin_path , cuda_inc_dirs , cuda_lib_dirs = f ()
440401 if cuda_bin_path :
0 commit comments