Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add MSVC support to setup.py for Windows builds
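- Introduced platform-specific compiler and linker flag handling: MSVC-style flags on Windows (`os.name == "nt"`), the existing GCC/Clang flags elsewhere.
- Adjusted `nvcc` path resolution for Windows: the binary is now located as `nvcc.exe` and the path is built with `os.path.join`.
- Ensured compatibility with MSVC by skipping the `-D_GLIBCXX_USE_CXX11_ABI` define and the `-lcuda` link argument on Windows.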
  • Loading branch information
sfinktah committed Nov 2, 2025
commit 63871657da4c18a231f5df64faca57f3edae5157
61 changes: 43 additions & 18 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,41 @@
# NVIDIA GPU compute capabilities this build accepts.
SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}

# Flags handed to the host C++ compiler.
CXX_FLAGS = [
    "-g",
    "-O3",
    "-fopenmp",
    "-lgomp",
    "-std=c++17",
    "-DENABLE_BF16",
]

# Flags handed to nvcc.
NVCC_FLAGS = [
    # Optimisation level and language standard.
    "-O3",
    "-std=c++17",
    # Undefine the macros that switch off half-precision operators/conversions.
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=8",
    # Forward "-v" to ptxas.
    "-Xptxas=-v",
    # Silence diagnostic number 174.
    "-diag-suppress=174",
]
# Compiler flags (platform-specific)
is_msvc = os.name == "nt"
if is_msvc:
# MSVC-compatible flags
CXX_FLAGS = [
"/O2", "/std:c++17", "/openmp", "/EHsc", "/DNOMINMAX",
"/D_ENABLE_EXTENDED_ALIGNED_STORAGE", "/MP", "/permissive-", "/Zc:__cplusplus",
"/D_WIN32", "/DUSE_CUDA",
]
NVCC_FLAGS = [
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--use_fast_math",
"--threads=8",
"-Xptxas=-v",
"-diag-suppress=174",
# Host compiler flags for MSVC
"-Xcompiler", "/DWIN32",
"-Xcompiler", "/DUSE_CUDA",
]
else:
# GCC/Clang flags (Linux/macOS)
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
NVCC_FLAGS = [
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--use_fast_math",
"--threads=8",
"-Xptxas=-v",
"-diag-suppress=174",
]

# Append flags from env if provided
cxx_append = os.getenv("CXX_APPEND_FLAGS", "").strip()
Expand All @@ -66,9 +89,10 @@
if nvcc_append:
NVCC_FLAGS += nvcc_append.split()

# Pass the value of torch's _GLIBCXX_USE_CXX11_ABI setting (as 0 or 1) to both
# the host C++ compile and the nvcc compile.
ABI = int(bool(torch._C._GLIBCXX_USE_CXX11_ABI))
CXX_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}")
NVCC_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}")
# The -D_GLIBCXX_USE_CXX11_ABI define is only added for non-MSVC builds; on
# MSVC ABI is pinned to 0 and torch is not queried at all.
if is_msvc:
    ABI = 0
else:
    # A missing attribute is treated the same as a false one.
    ABI = 1 if getattr(torch._C, "_GLIBCXX_USE_CXX11_ABI", False) else 0
    CXX_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}")
    NVCC_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}")

if CUDA_HOME is None:
raise RuntimeError(
Expand All @@ -79,8 +103,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:

Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
"""
nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
universal_newlines=True)
nvcc_name = "nvcc.exe" if os.name == "nt" else "nvcc"
nvcc_path = os.path.join(cuda_dir, "bin", nvcc_name)
nvcc_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
output = nvcc_output.split()
release_idx = output.index("release") + 1
nvcc_cuda_version = parse(output[release_idx].split(",")[0])
Expand Down Expand Up @@ -202,7 +227,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
"csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
],
extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
extra_link_args=['-lcuda'],
extra_link_args=([] if is_msvc else ['-lcuda']),
)
)

Expand Down