Add Windows (MSVC) build support to setup.py.

Selects MSVC-style host compiler flags when os.name == "nt", skips the
libstdc++ _GLIBCXX_USE_CXX11_ABI define there, resolves nvcc via
os.path.join (nvcc.exe on Windows), and links the CUDA driver library
with the platform-appropriate spelling.

Review changes relative to the submitted patch:
  * MSVC CXX_FLAGS keep ENABLE_BF16, which the GCC/Clang flags define.
  * /D_WIN32 removed: MSVC predefines _WIN32 for every Windows target.
  * extra_link_args uses cuda.lib on MSVC instead of an empty list; the
    non-Windows build links -lcuda, so the sm90 extension presumably
    needs the CUDA driver API on Windows as well (TODO confirm).

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Indentation and blank context lines were reconstructed from the hunk
line counts; verify against setup.py (or apply with --ignore-whitespace).
The post-image blob hash on the index line predates the review changes.

diff --git a/setup.py b/setup.py
index ab569e94..9679c82d 100644
--- a/setup.py
+++ b/setup.py
@@ -45,18 +45,41 @@
 # Supported NVIDIA GPU architectures.
 SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}
 
-# Compiler flags.
-CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
-NVCC_FLAGS = [
-    "-O3",
-    "-std=c++17",
-    "-U__CUDA_NO_HALF_OPERATORS__",
-    "-U__CUDA_NO_HALF_CONVERSIONS__",
-    "--use_fast_math",
-    "--threads=8",
-    "-Xptxas=-v",
-    "-diag-suppress=174",
-]
+# Compiler flags (platform-specific)
+is_msvc = os.name == "nt"
+if is_msvc:
+    # MSVC-compatible flags
+    CXX_FLAGS = [
+        "/O2", "/std:c++17", "/openmp", "/EHsc", "/DNOMINMAX",
+        "/D_ENABLE_EXTENDED_ALIGNED_STORAGE", "/MP", "/permissive-", "/Zc:__cplusplus",
+        "/DENABLE_BF16", "/DUSE_CUDA",
+    ]
+    NVCC_FLAGS = [
+        "-O3",
+        "-std=c++17",
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "--use_fast_math",
+        "--threads=8",
+        "-Xptxas=-v",
+        "-diag-suppress=174",
+        # Host compiler flags for MSVC
+        "-Xcompiler", "/DWIN32",
+        "-Xcompiler", "/DUSE_CUDA",
+    ]
+else:
+    # GCC/Clang flags (Linux/macOS)
+    CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
+    NVCC_FLAGS = [
+        "-O3",
+        "-std=c++17",
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "--use_fast_math",
+        "--threads=8",
+        "-Xptxas=-v",
+        "-diag-suppress=174",
+    ]
 
 # Append flags from env if provided
 cxx_append = os.getenv("CXX_APPEND_FLAGS", "").strip()
@@ -66,9 +89,10 @@
 if nvcc_append:
     NVCC_FLAGS += nvcc_append.split()
 
-ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
-CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
-NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
+ABI = 1 if not is_msvc and getattr(torch._C, "_GLIBCXX_USE_CXX11_ABI", False) else 0
+if not is_msvc:
+    CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
+    NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
 if CUDA_HOME is None:
     raise RuntimeError(
@@ -79,8 +103,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
 
     Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
     """
-    nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
-                                          universal_newlines=True)
+    nvcc_name = "nvcc.exe" if os.name == "nt" else "nvcc"
+    nvcc_path = os.path.join(cuda_dir, "bin", nvcc_name)
+    nvcc_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
     output = nvcc_output.split()
     release_idx = output.index("release") + 1
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
@@ -202,6 +227,6 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
                 "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu",
             ],
             extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
-            extra_link_args=['-lcuda'],
+            extra_link_args=(['cuda.lib'] if is_msvc else ['-lcuda']),
         )
     )