Update app.py
Browse files
app.py
CHANGED
@@ -1,59 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
-
from
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
config = MambaConfig(
|
8 |
-
hidden_size=512,
|
9 |
-
num_layers=6,
|
10 |
-
num_heads=8,
|
11 |
-
intermediate_size=2048,
|
12 |
-
max_position_embeddings=1024,
|
13 |
-
rms_norm=False,
|
14 |
-
residual_in_fp32=False,
|
15 |
-
fused_add_norm=False,
|
16 |
)
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
#
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Albert Gu, Tri Dao.
|
2 |
+
import warnings
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from packaging.version import parse, Version
|
7 |
+
from setuptools import setup, find_packages
|
8 |
+
import subprocess
|
9 |
+
|
10 |
+
|
11 |
import torch
|
12 |
+
from torch.utils.cpp_extension import (
|
13 |
+
BuildExtension,
|
14 |
+
CppExtension,
|
15 |
+
CUDAExtension,
|
16 |
+
CUDA_HOME,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
)
|
18 |
|
19 |
+
PACKAGE_NAME = "blackmamba"
|
20 |
+
VERSION = "0.0.1"
|
21 |
+
|
22 |
+
with open("README.md", "r", encoding="utf-8") as fh:
|
23 |
+
long_description = fh.read()
|
24 |
+
|
25 |
+
|
26 |
+
# ninja build does not work unless include_dirs are abs path
|
27 |
+
this_dir = os.path.dirname(os.path.abspath(__file__))
|
28 |
+
|
29 |
+
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
|
30 |
+
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
|
31 |
+
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
|
32 |
+
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
|
33 |
+
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
|
34 |
+
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
|
35 |
+
|
36 |
+
|
37 |
+
def get_cuda_bare_metal_version(cuda_dir):
|
38 |
+
raw_output = subprocess.check_output(
|
39 |
+
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
40 |
+
)
|
41 |
+
output = raw_output.split()
|
42 |
+
release_idx = output.index("release") + 1
|
43 |
+
bare_metal_version = parse(output[release_idx].split(",")[0])
|
44 |
+
|
45 |
+
return raw_output, bare_metal_version
|
46 |
+
|
47 |
+
|
48 |
+
def check_if_cuda_home_none(global_option: str) -> None:
|
49 |
+
if CUDA_HOME is not None:
|
50 |
+
return
|
51 |
+
# warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
|
52 |
+
# in that case.
|
53 |
+
warnings.warn(
|
54 |
+
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
|
55 |
+
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
|
56 |
+
"only images whose names contain 'devel' will provide nvcc."
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
def append_nvcc_threads(nvcc_extra_args):
|
61 |
+
return nvcc_extra_args + ["--threads", "4"]
|
62 |
+
|
63 |
+
|
64 |
+
ext_modules = []
|
65 |
+
if not SKIP_CUDA_BUILD:
|
66 |
+
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
67 |
+
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
68 |
+
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
69 |
+
|
70 |
+
check_if_cuda_home_none(PACKAGE_NAME)
|
71 |
+
# Check, if CUDA11 is installed for compute capability 8.0
|
72 |
+
cc_flag = []
|
73 |
+
if CUDA_HOME is not None:
|
74 |
+
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
|
75 |
+
if bare_metal_version < Version("11.6"):
|
76 |
+
raise RuntimeError(
|
77 |
+
f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
|
78 |
+
"Note: make sure nvcc has a supported version by running nvcc -V."
|
79 |
+
)
|
80 |
+
|
81 |
+
cc_flag.append("-gencode")
|
82 |
+
cc_flag.append("arch=compute_70,code=sm_70")
|
83 |
+
cc_flag.append("-gencode")
|
84 |
+
cc_flag.append("arch=compute_80,code=sm_80")
|
85 |
+
if bare_metal_version >= Version("11.8"):
|
86 |
+
cc_flag.append("-gencode")
|
87 |
+
cc_flag.append("arch=compute_90,code=sm_90")
|
88 |
+
|
89 |
+
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
|
90 |
+
# torch._C._GLIBCXX_USE_CXX11_ABI
|
91 |
+
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
|
92 |
+
if FORCE_CXX11_ABI:
|
93 |
+
torch._C._GLIBCXX_USE_CXX11_ABI = True
|
94 |
+
|
95 |
+
ext_modules.append(
|
96 |
+
CUDAExtension(
|
97 |
+
name="selective_scan_cuda",
|
98 |
+
sources=[
|
99 |
+
"csrc/selective_scan/selective_scan.cpp",
|
100 |
+
"csrc/selective_scan/selective_scan_fwd_fp32.cu",
|
101 |
+
"csrc/selective_scan/selective_scan_fwd_fp16.cu",
|
102 |
+
"csrc/selective_scan/selective_scan_fwd_bf16.cu",
|
103 |
+
"csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
|
104 |
+
"csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
|
105 |
+
"csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
|
106 |
+
"csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
|
107 |
+
"csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
|
108 |
+
"csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
|
109 |
+
],
|
110 |
+
extra_compile_args={
|
111 |
+
"cxx": ["-O3", "-std=c++17"],
|
112 |
+
"nvcc": append_nvcc_threads(
|
113 |
+
[
|
114 |
+
"-O3",
|
115 |
+
"-std=c++17",
|
116 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
117 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
118 |
+
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
119 |
+
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
120 |
+
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
121 |
+
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
122 |
+
"--expt-relaxed-constexpr",
|
123 |
+
"--expt-extended-lambda",
|
124 |
+
"--use_fast_math",
|
125 |
+
"--ptxas-options=-v",
|
126 |
+
"-lineinfo",
|
127 |
+
]
|
128 |
+
+ cc_flag
|
129 |
+
),
|
130 |
+
},
|
131 |
+
include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
|
132 |
+
)
|
133 |
+
)
|
134 |
+
|
135 |
+
|
136 |
+
setup(
|
137 |
+
name=PACKAGE_NAME,
|
138 |
+
version=VERSION,
|
139 |
+
description="Blackmamba state-space + MoE model",
|
140 |
+
long_description=long_description,
|
141 |
+
long_description_content_type="text/markdown",
|
142 |
+
packages=find_packages(include=['ops'],),
|
143 |
+
exclude=(
|
144 |
+
"csrc",
|
145 |
+
"blackmamba.egg-info",
|
146 |
+
),
|
147 |
+
ext_modules=ext_modules,
|
148 |
+
cmdclass={"build_ext": BuildExtension},
|
149 |
+
python_requires=">=3.7",
|
150 |
+
install_requires=[
|
151 |
+
"torch",
|
152 |
+
"packaging",
|
153 |
+
"ninja",
|
154 |
+
"einops",
|
155 |
+
"triton",
|
156 |
+
"transformers",
|
157 |
+
"causal_conv1d>=1.1.0",
|
158 |
+
],
|
159 |
+
)
|