Spaces:
Sleeping
Sleeping
| import os | |
| import subprocess | |
| from packaging.version import parse, Version | |
| import torch | |
| from setuptools import setup | |
| from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME | |
def get_cuda_bare_metal_version(cuda_dir):
    """Run the nvcc found under *cuda_dir* and return its release version.

    Executes ``<cuda_dir>/bin/nvcc -V`` and parses the token that follows the
    word ``release`` in its output (e.g. ``"11.8,"`` -> ``Version("11.8")``).

    Args:
        cuda_dir: Root directory of the CUDA toolkit, as passed in by the
            caller (this file passes ``CUDA_HOME``). ``None`` is rejected
            with an explicit error.

    Returns:
        A ``(raw_output, bare_metal_version)`` tuple: the full text printed
        by ``nvcc -V`` and the parsed release version.

    Raises:
        OSError: if *cuda_dir* is ``None``, or if nvcc cannot be launched.
        subprocess.CalledProcessError: if nvcc exits with a non-zero status.
        ValueError: if the nvcc output has no version after ``release``.
    """
    if cuda_dir is None:
        # Without this check, joining None onto a path fails with an opaque
        # TypeError that does not tell the user what is actually wrong.
        raise OSError(
            "CUDA_HOME is not set, so nvcc cannot be located. "
            "Set CUDA_HOME to the root of your CUDA installation."
        )
    nvcc_path = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
    tokens = raw_output.split()
    # A token must follow "release"; checking tokens[:-1] also rules out
    # "release" being the very last token (which would raise IndexError).
    if "release" not in tokens[:-1]:
        raise ValueError(
            "Could not find a CUDA release version in the output of "
            "'{} -V':\n{}".format(nvcc_path, raw_output)
        )
    release_idx = tokens.index("release") + 1
    # The token looks like "11.8," -- drop the trailing comma before parsing.
    bare_metal_version = parse(tokens[release_idx].split(",")[0])
    return raw_output, bare_metal_version
def append_nvcc_threads(nvcc_extra_args):
    """Return nvcc arguments, extended with a parallel-compilation flag if supported.

    The CUDA toolkit version is read from the nvcc under ``CUDA_HOME``. Below
    11.2 the arguments are returned untouched. Otherwise a new list is returned
    with ``--threads <N>`` appended, where N comes from the ``NVCC_THREADS``
    environment variable (an unset or empty value falls back to ``"4"``).

    Args:
        nvcc_extra_args: List of extra nvcc flags; it is not mutated.

    Returns:
        ``nvcc_extra_args`` itself (old toolkits) or a new, extended list.
    """
    _, cuda_version = get_cuda_bare_metal_version(CUDA_HOME)
    # nvcc only accepts --threads from CUDA 11.2 onward.
    if cuda_version < Version("11.2"):
        return nvcc_extra_args
    thread_count = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", thread_count]
# Build the "fused_dense_lib" extension from the C++ binding and the CUDA
# kernel source, compiling both host (cxx) and device (nvcc) code with -O3.
# nvcc additionally gets a --threads flag when the toolkit supports it.
_fused_dense_extension = CUDAExtension(
    name="fused_dense_lib",
    sources=["fused_dense.cpp", "fused_dense_cuda.cu"],
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": append_nvcc_threads(["-O3"]),
    },
)

# torch's BuildExtension replaces the stock build_ext so that .cu sources
# are routed through nvcc.
setup(
    name="fused_dense_lib",
    ext_modules=[_fused_dense_extension],
    cmdclass={"build_ext": BuildExtension},
)