Skip to content
Next Next commit
Modify setup.py to make it suitable for Windows builds and
multi-architecture builds.
  • Loading branch information
mengqin committed Dec 7, 2025
commit 2e0b94810e4613bab15dab94cbaaa60c4c0619bf
64 changes: 49 additions & 15 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,33 @@
SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "10.0", "12.0", "12.1"}

# Compiler flags.
# CXX_FLAGS goes to the host C++ compiler and NVCC_FLAGS to nvcc; both are
# handed to the extensions through ``extra_compile_args`` ("cxx" / "nvcc").
# These are the GCC-style spellings of the host flags.
# NOTE(review): "-lgomp" is a linker option listed among the compile flags;
# confirm it is intentional.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
"-O3",
"-std=c++17",
# Undefine the macros that would otherwise disable half-precision
# operators and conversions in the CUDA headers.
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--use_fast_math",
"--threads=8",
# Forward "-v" (verbose output) to ptxas.
"-Xptxas=-v",
# Suppress nvcc diagnostic number 174.
"-diag-suppress=174",
]
# nvcc flags shared by every platform. The platform-specific flags are
# appended after these, so each final list keeps its original order.
_NVCC_COMMON_FLAGS = [
    "-O3",
    "-std=c++17",
    # Undefine the macros that would otherwise disable half-precision
    # operators and conversions in the CUDA headers.
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=8",
]

if os.name == "nt":
    # Windows: the host compiler takes MSVC-style ("/...") options.
    CXX_FLAGS = [
        "/Zi",
        "/O2",
        "/openmp",
        "/std:c++17",
        "/DENABLE_BF16",
        "/MD",
        "/permissive-",
    ]
    # NOTE(review): "-Xptxas=-v" is deliberately not passed on Windows; the
    # reason is not recorded here -- confirm before re-enabling it.
    NVCC_FLAGS = _NVCC_COMMON_FLAGS + [
        "-diag-suppress=174",
        "-diag-suppress=177",
        "-D_WIN32=1",
        "-DUSE_CUDA=1",
    ]
else:
    # Everything else: GCC-style host options.
    CXX_FLAGS = [
        "-g",
        "-O3",
        "-fopenmp",
        "-lgomp",
        "-std=c++17",
        "-DENABLE_BF16",
    ]
    NVCC_FLAGS = _NVCC_COMMON_FLAGS + [
        # Forward "-v" (verbose output) to ptxas.
        "-Xptxas=-v",
        "-diag-suppress=174",
    ]

# Append flags from env if provided
cxx_append = os.getenv("CXX_APPEND_FLAGS", "").strip()
Expand Down Expand Up @@ -140,19 +156,31 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"CUDA 12.8 or higher is required for compute capability 12.0.")

# Add target compute capabilities to NVCC flags.
NVCC_FLAGS_SM80 = list(NVCC_FLAGS)
NVCC_FLAGS_SM89 = list(NVCC_FLAGS)
NVCC_FLAGS_SM90 = list(NVCC_FLAGS)
for capability in compute_capabilities:
if capability.startswith("8.0"):
HAS_SM80 = True
num = "80"
NVCC_FLAGS_SM80 += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS_SM80 += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
elif capability.startswith("8.6"):
HAS_SM86 = True
num = "86"
elif capability.startswith("8.9"):
HAS_SM89 = True
num = "89"
NVCC_FLAGS_SM89 += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS_SM89 += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
elif capability.startswith("9.0"):
HAS_SM90 = True
num = "90a"
NVCC_FLAGS_SM90 += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
if capability.endswith("+PTX"):
NVCC_FLAGS_SM90 += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
elif capability.startswith("10.0"):
HAS_SM100 = True
num = "100a"
Expand All @@ -179,7 +207,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"csrc/qattn/pybind_sm80.cpp",
"csrc/qattn/qk_int_sv_f16_cuda_sm80.cu",
],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS_SM80},
)
)

Expand All @@ -197,10 +225,16 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
"csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu",
],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS_SM89},
)
)

# Link argument for the CUDA driver library, later passed as
# ``extra_link_args``: on Windows the import library is named by file name,
# elsewhere the ``-l`` form is used.
cuda_lib = ["cuda.lib"] if os.name == "nt" else ["-lcuda"]

if HAS_SM90:
ext_modules.append(
CUDAExtension(
Expand All @@ -209,8 +243,8 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"csrc/qattn/pybind_sm90.cpp",
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_link_args=['-lcuda'],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS_SM90},
extra_link_args=cuda_lib,
)
)

Expand Down