Skip to content
This repository was archived by the owner on Feb 18, 2026. It is now read-only.

Commit 3882b5e

Browse files
committed
pre-commit
1 parent f039153 commit 3882b5e

14 files changed

Lines changed: 75 additions & 107 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,5 +72,6 @@ exclude: |
7272
^third_party/amd/backend/include/roctracer/|
7373
^third_party/amd/backend/lib/|
7474
^third_party/nvidia/backend/include/cuda.h|
75+
^third_party/dlfcn|
7576
^third_party/f2reduce
7677
)

include/triton/Dialect/Triton/IR/Utility.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ SmallVector<T> convertType(const VecU &in) {
2020
}
2121

2222
template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
23-
return std::accumulate(arr.begin(), arr.end(), static_cast<Int>(1), std::multiplies{});
23+
return std::accumulate(arr.begin(), arr.end(), static_cast<Int>(1),
24+
std::multiplies{});
2425
}
2526
template <typename VecT> auto product(const VecT &vec) {
2627
return product(llvm::ArrayRef(vec));

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,8 +569,8 @@ LogicalResult LoopPipelinerInternal::createKernel(
569569
auto stageDef = stages.find(def);
570570
if (stageDef == stages.end() || stageDef->second == useStage)
571571
continue;
572-
auto remap = loopArgMap.find(
573-
std::make_pair(operand->get(), static_cast<unsigned>(useStage) - stageDef->second));
572+
auto remap = loopArgMap.find(std::make_pair(
573+
operand->get(), static_cast<unsigned>(useStage) - stageDef->second));
574574
assert(remap != loopArgMap.end());
575575
nestedNewOp->setOperand(operand->getOperandNumber(),
576576
newForOp.getRegionIterArgs()[remap->second]);

lib/Tools/LayoutUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ ensureLayoutNotLargerThan(const LinearLayout &layout,
7979
continue;
8080
}
8181
assert(llvm::isPowerOf2_32(outValue));
82-
sortedBases.emplace_back(inDimName, static_cast<int>(basisIdx), outValue);
82+
sortedBases.emplace_back(inDimName, static_cast<int>(basisIdx),
83+
outValue);
8384
}
8485
}
8586
// From the largest basis to the smallest.

python/test/unit/language/test_matmul.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,10 @@ def get_src_element_ty_size(dtype_str):
8787

8888
@pytest.mark.parametrize("dtype_src_str", ["float32", "tensorfloat32", "float16", "float8e5"])
8989
@pytest.mark.parametrize("dtype_dst_str", ["float32", "float16"])
90-
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 16, 4), (64, 128, 32, 4), (32, 32, 32, 4),
9190
# Skip 256x32@32x128 because it's too large on GPU with max_shared_mem = 101376
92-
# (256, 128, 32, 4), (64, 512, 32, 2),
93-
(64, 512, 32, 2),
94-
(512, 64, 32, 2), (64, 16, 16, 4)])
91+
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES", [(128, 128, 16, 4), (64, 128, 32, 4), (32, 32, 32, 4),
92+
# (256, 128, 32, 4),
93+
(64, 512, 32, 2), (512, 64, 32, 2), (64, 16, 16, 4)])
9594
@pytest.mark.parametrize("NUM_CTAS", [1, 2])
9695
@pytest.mark.parametrize("NUM_WARPS", [4, 8])
9796
@pytest.mark.parametrize("EPILOGUE_SUBTILE", [True, False])

python/test/unit/language/test_subprocess.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,15 @@ def test_print(func_type: str, data_type: str, device: str):
5454

5555
# On Windows, some info are sent to stdout, so we filter them out
5656
if os.name == "nt":
57-
outs = [line for line in outs if not any(x in line for x in [
58-
"ptxas info",
59-
"bytes stack frame",
60-
"main.c",
61-
"warning C4819: The file contains a character that cannot be represented in the current code page",
62-
"Creating library main.lib and object main.exp",
63-
])]
57+
outs = [
58+
line for line in outs if not any(x in line for x in [
59+
"ptxas info",
60+
"bytes stack frame",
61+
"main.c",
62+
"warning C4819: The file contains a character that cannot be represented in the current code page",
63+
"Creating library main.lib and object main.exp",
64+
])
65+
]
6466

6567
# The total number of elements in the 1-D tensor to print.
6668
N = 128

python/triton/runtime/build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
116116
cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries)
117117

118118
try:
119-
ret = subprocess.check_call(cc_cmd)
119+
subprocess.check_call(cc_cmd)
120120
except Exception as e:
121121
print("Failed to compile. cc_cmd:", cc_cmd)
122122
raise e

python/triton/windows_utils.py

Lines changed: 40 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,11 @@ def max_version(
5454

5555

5656
def check_msvc(msvc_base_path: Path, version: str) -> bool:
57-
return all(
58-
x.exists()
59-
for x in [
60-
msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
61-
msvc_base_path / version / "include" / "vcruntime.h",
62-
msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
63-
]
64-
)
57+
return all(x.exists() for x in [
58+
msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
59+
msvc_base_path / version / "include" / "vcruntime.h",
60+
msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
61+
])
6562

6663

6764
def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
@@ -72,20 +69,16 @@ def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
7269

7370
version = os.getenv("VCToolsVersion")
7471
if not check_msvc(msvc_base_path, version):
75-
warnings.warn(
76-
f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
77-
f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
78-
"but this MSVC installation is incomplete."
79-
)
72+
warnings.warn(f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
73+
f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
74+
"but this MSVC installation is incomplete.")
8075
return None, None
8176

8277
return msvc_base_path, version
8378

8479

8580
def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
86-
vswhere_path = find_in_program_files(
87-
r"Microsoft Visual Studio\Installer\vswhere.exe"
88-
)
81+
vswhere_path = find_in_program_files(r"Microsoft Visual Studio\Installer\vswhere.exe")
8982
if vswhere_path is None:
9083
return None, None
9184

@@ -111,9 +104,7 @@ def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
111104
if not msvc_base_path.exists():
112105
return None, None
113106

114-
version = max_version(
115-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
116-
)
107+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
117108
if version is None:
118109
return None, None
119110

@@ -132,9 +123,7 @@ def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
132123
if not msvc_base_path.exists():
133124
continue
134125

135-
version = max_version(
136-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
137-
)
126+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
138127
if version is None:
139128
continue
140129

@@ -153,9 +142,7 @@ def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
153142
paths = sorted(paths)[::-1]
154143
for msvc_base_path in paths:
155144
msvc_base_path = Path(msvc_base_path)
156-
version = max_version(
157-
os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path)
158-
)
145+
version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
159146
if version is None:
160147
continue
161148
return msvc_base_path, version
@@ -188,13 +175,10 @@ def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
188175

189176

190177
def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
191-
return all(
192-
x.exists()
193-
for x in [
194-
winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
195-
winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
196-
]
197-
)
178+
return all(x.exists() for x in [
179+
winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
180+
winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
181+
])
198182

199183

200184
def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
@@ -207,18 +191,14 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
207191
if version is None:
208192
version = os.getenv("WindowsSDKVer")
209193
if version is None:
210-
warnings.warn(
211-
f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
212-
"but WindowsSDKVersion (or WindowsSDKVer) is not set."
213-
)
194+
warnings.warn(f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
195+
"but WindowsSDKVersion (or WindowsSDKVer) is not set.")
214196
return None, None
215197
version = version.rstrip("\\")
216198
if not check_winsdk(winsdk_base_path, version):
217-
warnings.warn(
218-
f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
219-
f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
220-
"but this Windows SDK installation is incomplete."
221-
)
199+
warnings.warn(f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
200+
f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
201+
"but this Windows SDK installation is incomplete.")
222202
return None, None
223203

224204
return winsdk_base_path, version
@@ -227,9 +207,7 @@ def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
227207
def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
228208
try:
229209
reg = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
230-
key = winreg.OpenKeyEx(
231-
reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0"
232-
)
210+
key = winreg.OpenKeyEx(reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0")
233211
folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
234212
winreg.CloseKey(key)
235213
except OSError:
@@ -296,9 +274,7 @@ def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
296274

297275

298276
@functools.lru_cache
299-
def find_msvc_winsdk(
300-
env_only: bool = False,
301-
) -> tuple[Optional[str], list[str], list[str]]:
277+
def find_msvc_winsdk(env_only: bool = False, ) -> tuple[Optional[str], list[str], list[str]]:
302278
msvc_bin_path, msvc_inc_dirs, msvc_lib_dirs = find_msvc(env_only)
303279
winsdk_inc_dirs, winsdk_lib_dirs = find_winsdk(env_only)
304280
return (
@@ -314,9 +290,9 @@ def find_python() -> list[str]:
314290
if sysconfig.get_config_var("Py_GIL_DISABLED"):
315291
version += "t"
316292
for python_base_path in [
317-
sys.exec_prefix,
318-
sys.base_exec_prefix,
319-
os.path.dirname(sys.executable),
293+
sys.exec_prefix,
294+
sys.base_exec_prefix,
295+
os.path.dirname(sys.executable),
320296
]:
321297
python_lib_dir = Path(python_base_path) / "libs"
322298
if (python_lib_dir / f"python{version}.lib").exists():
@@ -328,44 +304,35 @@ def find_python() -> list[str]:
328304

329305
def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
330306
# pip
331-
if all(
332-
x.exists()
333-
for x in [
307+
if all(x.exists() for x in [
334308
base_path / "cuda_nvcc" / "bin" / "ptxas.exe",
335309
base_path / "cuda_runtime" / "include" / "cuda.h",
336310
base_path / "cuda_runtime" / "lib" / "x64" / "cuda.lib",
337-
]
338-
):
311+
]):
339312
return (
340313
str(base_path / "cuda_nvcc" / "bin"),
341314
[str(base_path / "cuda_runtime" / "include")],
342315
[str(base_path / "cuda_runtime" / "lib" / "x64")],
343316
)
344317

345318
# conda
346-
if all(
347-
x.exists()
348-
for x in [
319+
if all(x.exists() for x in [
349320
base_path / "bin" / "ptxas.exe",
350321
base_path / "include" / "cuda.h",
351322
base_path / "lib" / "cuda.lib",
352-
]
353-
):
323+
]):
354324
return (
355325
str(base_path / "bin"),
356326
[str(base_path / "include")],
357327
[str(base_path / "lib")],
358328
)
359329

360330
# bundled or system-wide
361-
if all(
362-
x.exists()
363-
for x in [
331+
if all(x.exists() for x in [
364332
base_path / "bin" / "ptxas.exe",
365333
base_path / "include" / "cuda.h",
366334
base_path / "lib" / "x64" / "cuda.lib",
367-
]
368-
):
335+
]):
369336
return (
370337
str(base_path / "bin"),
371338
[str(base_path / "include")],
@@ -382,19 +349,15 @@ def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
382349
continue
383350

384351
cuda_base_path = Path(cuda_base_path)
385-
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
386-
cuda_base_path
387-
)
352+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
388353
if cuda_bin_path:
389354
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
390355

391356
return None, [], []
392357

393358

394359
def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
395-
cuda_base_path = (
396-
Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia"
397-
)
360+
cuda_base_path = (Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia")
398361
return check_and_find_cuda(cuda_base_path)
399362

400363

@@ -418,9 +381,7 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
418381
paths = sorted(paths)[::-1]
419382
for cuda_base_path in paths:
420383
cuda_base_path = Path(cuda_base_path)
421-
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(
422-
cuda_base_path
423-
)
384+
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
424385
if cuda_bin_path:
425386
return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
426387

@@ -430,11 +391,11 @@ def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
430391
@functools.lru_cache
431392
def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
432393
for f in [
433-
find_cuda_env,
434-
find_cuda_bundled,
435-
find_cuda_pip,
436-
find_cuda_conda,
437-
find_cuda_hardcoded,
394+
find_cuda_env,
395+
find_cuda_bundled,
396+
find_cuda_pip,
397+
find_cuda_conda,
398+
find_cuda_hardcoded,
438399
]:
439400
cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = f()
440401
if cuda_bin_path:

setup.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,8 @@ def download_and_copy_dependencies():
603603
download_and_copy(
604604
name="nvidia/cupti-" + NVIDIA_TOOLCHAIN_VERSION["cupti"],
605605
# On Windows, each version of CUPTI has a specific DLL name. Remember to update this with `nvidia-toolchain-version.json`.
606-
src_func=lambda system, arch, version: f"cuda_cupti-{system}-{arch}-{version}-archive/lib/cupti64_2025.1.1.dll",
606+
src_func=lambda system, arch, version:
607+
f"cuda_cupti-{system}-{arch}-{version}-archive/lib/cupti64_2025.1.1.dll",
607608
dst_path="third_party/nvidia/backend/lib/cupti/cupti64_2025.1.1.dll",
608609
variable="TRITON_CUPTI_LIB_PATH",
609610
version=NVIDIA_TOOLCHAIN_VERSION["cupti"],
@@ -786,13 +787,12 @@ def get_entry_points():
786787
return entry_points
787788

788789

789-
def get_backend_package_data(path, exclude_dirs=("__pycache__", "include"), include_files=("include/cuda.h",)):
790+
def get_backend_package_data(path, exclude_dirs=("__pycache__", "include"), include_files=("include/cuda.h", )):
790791
if path is None or not os.path.exists(path):
791792
return []
792793
out = [
793-
os.path.join(os.path.relpath(p, path), "*")
794-
for p, _, _, in os.walk(path)
795-
if not any(x in p.split(os.path.sep) for x in exclude_dirs)
794+
os.path.join(os.path.relpath(p, path), "*") for p, _, _, in os.walk(path) if not any(x in p.split(os.path.sep)
795+
for x in exclude_dirs)
796796
]
797797
for x in include_files:
798798
if os.path.exists(os.path.join(path, x)):
@@ -819,7 +819,7 @@ def get_package_data():
819819
for x in os.listdir(backend.tools_dir):
820820
yield (f"triton.tools.extra.{x}", get_backend_package_data(os.path.join(backend.tools_dir, x)))
821821

822-
yield ("triton.runtime", get_backend_package_data("python/triton/runtime", exclude_dirs=("__pycache__",)))
822+
yield ("triton.runtime", get_backend_package_data("python/triton/runtime", exclude_dirs=("__pycache__", )))
823823

824824

825825
def get_git_commit_hash(length=8):

third_party/nvidia/backend/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,8 @@ def make_cubin(self, src, metadata, opt, capability):
433433
try:
434434
# close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
435435
# On Windows, both stdout and stderr need to be redirected to flog
436-
subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog, stderr=flog)
436+
subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
437+
stderr=flog)
437438
except subprocess.CalledProcessError as e:
438439
with open(flog.name) as log_file:
439440
log = log_file.read()

0 commit comments

Comments
 (0)