yukimama committed on
Commit 400a6fd · verified · 1 Parent(s): 877ed39

Update app.py

Files changed (1):
  1. app.py +156 -56
app.py CHANGED
@@ -1,59 +1,159 @@
  import torch
- from mamba_block import MambaBlock
- from mamba_config import MambaConfig
- from mamba_layer import MambaLayer
-
- # Create a Mamba configuration
- config = MambaConfig(
-     hidden_size=512,
-     num_layers=6,
-     num_heads=8,
-     intermediate_size=2048,
-     max_position_embeddings=1024,
-     rms_norm=False,
-     residual_in_fp32=False,
-     fused_add_norm=False,
  )

- # Create a Mamba model
- class MambaModel(torch.nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.layers = torch.nn.ModuleList([MambaBlock(config, MambaLayer) for _ in range(config.num_layers)])
-         self.norm = torch.nn.LayerNorm(config.hidden_size)
-
-     def forward(self, hidden_states: torch.Tensor):
-         residual = None
-         for layer in self.layers:
-             hidden_states, residual = layer(hidden_states, residual)
-         hidden_states = self.norm(hidden_states + residual if residual is not None else hidden_states)
-         return hidden_states
-
- # Create a model instance
- mamba_model = MambaModel(config)
- mamba_model.eval()
-
- # Function to generate text from a given prompt using the Mamba model
- def generate_text(prompt, model, max_length=50):
-     # Assumes the prompt has already been converted into embedding vectors
-     hidden_states = torch.randn(1, len(prompt), config.hidden_size)  # assumes the input sequence length is len(prompt)
-
-     with torch.no_grad():
-         output = model(hidden_states)
-
-     # The model output still has to be converted into readable text.
-     # This is just an example; in practice you would need a decoder to turn the output into text.
-     generated_text = "here is the generated text"  # this should be your actual generated text
-
-     return generated_text
-
- # Wrapper that returns uncensored text for a given prompt
- def generate_uncensored_text(prompt, max_length=50):
-     mamba_text = generate_text(prompt, mamba_model, max_length)
-     return mamba_text
-
- # Example usage
- prompt = "I want to generate some uncensored text."
- uncensored_text = generate_uncensored_text(prompt)
- print(uncensored_text)
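
The removed app.py leaves the decoding step as a placeholder string. A minimal sketch of what that missing step might look like, assuming a hypothetical tied embedding/LM head and greedy decoding (none of these names exist in the repo):

import torch

vocab_size = 32000  # assumption: size of a hypothetical tokenizer vocabulary
hidden_size = 512   # matches config.hidden_size in the removed script

embedding = torch.nn.Embedding(vocab_size, hidden_size)
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
lm_head.weight = embedding.weight  # weight tying, a common (not required) choice

def decode_greedy(hidden_states):
    # hidden_states: (batch, seq_len, hidden_size), e.g. MambaModel's output
    logits = lm_head(hidden_states)           # (batch, seq_len, vocab_size)
    return logits.argmax(dim=-1)[0].tolist()  # greedy token ids for batch row 0

token_ids = decode_greedy(torch.randn(1, 8, hidden_size))
# A real tokenizer would map these ids back to text; this script has none.
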
+ # Copyright (c) 2023, Albert Gu, Tri Dao.
+ import warnings
+ import os
+ from pathlib import Path
+
+ from packaging.version import parse, Version
+ from setuptools import setup, find_packages
+ import subprocess
+
+
  import torch
+ from torch.utils.cpp_extension import (
+     BuildExtension,
+     CppExtension,
+     CUDAExtension,
+     CUDA_HOME,
  )

+ PACKAGE_NAME = "blackmamba"
+ VERSION = "0.0.1"
+
+ with open("README.md", "r", encoding="utf-8") as fh:
+     long_description = fh.read()
+
+
+ # ninja build does not work unless include_dirs are abs path
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+
+ # FORCE_BUILD: force a fresh local build instead of attempting to find prebuilt wheels
+ # SKIP_CUDA_BUILD: allow CI to use a plain `python setup.py sdist` run to copy over raw files, without any CUDA compilation
+ FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
+ SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+ # For CI, we want the option to build with the C++11 ABI, since the nvcr images use the C++11 ABI
+ FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+
+
+ def get_cuda_bare_metal_version(cuda_dir):
+     # Parse the release number (e.g. "11.8") out of `nvcc -V` output
+     raw_output = subprocess.check_output(
+         [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+     )
+     output = raw_output.split()
+     release_idx = output.index("release") + 1
+     bare_metal_version = parse(output[release_idx].split(",")[0])
+
+     return raw_output, bare_metal_version
+
+
+ def check_if_cuda_home_none(global_option: str) -> None:
+     if CUDA_HOME is not None:
+         return
+     # Warn instead of erroring, because the user could be downloading prebuilt wheels,
+     # in which case nvcc is not necessary.
+     warnings.warn(
+         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
+         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
+         "only images whose names contain 'devel' will provide nvcc."
+     )
+
+
+ def append_nvcc_threads(nvcc_extra_args):
+     return nvcc_extra_args + ["--threads", "4"]
+
+
+ ext_modules = []
+ if not SKIP_CUDA_BUILD:
+     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+     TORCH_MAJOR = int(torch.__version__.split(".")[0])
+     TORCH_MINOR = int(torch.__version__.split(".")[1])
+
+     check_if_cuda_home_none(PACKAGE_NAME)
+     # Check if CUDA 11 is installed for compute capability 8.0
+     cc_flag = []
+     if CUDA_HOME is not None:
+         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+         if bare_metal_version < Version("11.6"):
+             raise RuntimeError(
+                 f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
+                 "Note: make sure nvcc has a supported version by running nvcc -V."
+             )
+
+         cc_flag.append("-gencode")
+         cc_flag.append("arch=compute_70,code=sm_70")
+         cc_flag.append("-gencode")
+         cc_flag.append("arch=compute_80,code=sm_80")
+         if bare_metal_version >= Version("11.8"):
+             cc_flag.append("-gencode")
+             cc_flag.append("arch=compute_90,code=sm_90")
+
+     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+     # torch._C._GLIBCXX_USE_CXX11_ABI
+     # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+     if FORCE_CXX11_ABI:
+         torch._C._GLIBCXX_USE_CXX11_ABI = True
+
+     ext_modules.append(
+         CUDAExtension(
+             name="selective_scan_cuda",
+             sources=[
+                 "csrc/selective_scan/selective_scan.cpp",
+                 "csrc/selective_scan/selective_scan_fwd_fp32.cu",
+                 "csrc/selective_scan/selective_scan_fwd_fp16.cu",
+                 "csrc/selective_scan/selective_scan_fwd_bf16.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
+                 "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
+                 "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
+                 "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
+             ],
+             extra_compile_args={
+                 "cxx": ["-O3", "-std=c++17"],
+                 "nvcc": append_nvcc_threads(
+                     [
+                         "-O3",
+                         "-std=c++17",
+                         "-U__CUDA_NO_HALF_OPERATORS__",
+                         "-U__CUDA_NO_HALF_CONVERSIONS__",
+                         "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+                         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+                         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+                         "--expt-relaxed-constexpr",
+                         "--expt-extended-lambda",
+                         "--use_fast_math",
+                         "--ptxas-options=-v",
+                         "-lineinfo",
+                     ]
+                     + cc_flag
+                 ),
+             },
+             include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
+         )
+     )
+
+
+ setup(
+     name=PACKAGE_NAME,
+     version=VERSION,
+     description="Blackmamba state-space + MoE model",
+     long_description=long_description,
+     long_description_content_type="text/markdown",
+     packages=find_packages(
+         include=["ops"],
+         exclude=(
+             "csrc",
+             "blackmamba.egg-info",
+         ),
+     ),
+     ext_modules=ext_modules,
+     cmdclass={"build_ext": BuildExtension},
+     python_requires=">=3.7",
+     install_requires=[
+         "torch",
+         "packaging",
+         "ninja",
+         "einops",
+         "triton",
+         "transformers",
+         "causal_conv1d>=1.1.0",
+     ],
+ )
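
A note on the CUDA gating in the new setup.py: get_cuda_bare_metal_version() extracts the release number from `nvcc -V` output, and that single parsed Version drives both the hard CUDA 11.6 floor and the optional sm_90 -gencode flags. A minimal sketch of the parsing and comparison logic, runnable without CUDA (the nvcc banner line is illustrative, not captured output):

from packaging.version import parse, Version

# Illustrative `nvcc -V` line (format only): the parser takes the
# comma-terminated token that follows the word "release".
sample = "Cuda compilation tools, release 11.8, V11.8.89"
words = sample.split()
bare_metal_version = parse(words[words.index("release") + 1].split(",")[0])

print(bare_metal_version >= Version("11.6"))  # True -> build allowed
print(bare_metal_version >= Version("11.8"))  # True -> sm_90 -gencode flags added

Setting MAMBA_SKIP_CUDA_BUILD=TRUE bypasses the whole extension-build block, which is what a plain `python setup.py sdist` run in CI relies on.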