yukimama committed
Commit a330b1e · verified · 1 parent: 400a6fd

Update app.py

Files changed (1):
  1. app.py +16 -159
app.py CHANGED
@@ -1,159 +1,16 @@
- # Copyright (c) 2023, Albert Gu, Tri Dao.
- import warnings
- import os
- from pathlib import Path
-
- from packaging.version import parse, Version
- from setuptools import setup, find_packages
- import subprocess
-
-
- import torch
- from torch.utils.cpp_extension import (
-     BuildExtension,
-     CppExtension,
-     CUDAExtension,
-     CUDA_HOME,
- )
-
- PACKAGE_NAME = "blackmamba"
- VERSION = "0.0.1"
-
- with open("README.md", "r", encoding="utf-8") as fh:
-     long_description = fh.read()
-
-
- # ninja build does not work unless include_dirs are abs path
- this_dir = os.path.dirname(os.path.abspath(__file__))
-
- # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
- # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
- FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
- SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
- # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
- FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"
-
-
- def get_cuda_bare_metal_version(cuda_dir):
-     raw_output = subprocess.check_output(
-         [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
-     )
-     output = raw_output.split()
-     release_idx = output.index("release") + 1
-     bare_metal_version = parse(output[release_idx].split(",")[0])
-
-     return raw_output, bare_metal_version
-
-
- def check_if_cuda_home_none(global_option: str) -> None:
-     if CUDA_HOME is not None:
-         return
-     # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
-     # in that case.
-     warnings.warn(
-         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
-         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
-         "only images whose names contain 'devel' will provide nvcc."
-     )
-
-
- def append_nvcc_threads(nvcc_extra_args):
-     return nvcc_extra_args + ["--threads", "4"]
-
-
- ext_modules = []
- if not SKIP_CUDA_BUILD:
-     print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
-     TORCH_MAJOR = int(torch.__version__.split(".")[0])
-     TORCH_MINOR = int(torch.__version__.split(".")[1])
-
-     check_if_cuda_home_none(PACKAGE_NAME)
-     # Check, if CUDA11 is installed for compute capability 8.0
-     cc_flag = []
-     if CUDA_HOME is not None:
-         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-         if bare_metal_version < Version("11.6"):
-             raise RuntimeError(
-                 f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
-                 "Note: make sure nvcc has a supported version by running nvcc -V."
-             )
-
-     cc_flag.append("-gencode")
-     cc_flag.append("arch=compute_70,code=sm_70")
-     cc_flag.append("-gencode")
-     cc_flag.append("arch=compute_80,code=sm_80")
-     if bare_metal_version >= Version("11.8"):
-         cc_flag.append("-gencode")
-         cc_flag.append("arch=compute_90,code=sm_90")
-
-     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
-     # torch._C._GLIBCXX_USE_CXX11_ABI
-     # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
-     if FORCE_CXX11_ABI:
-         torch._C._GLIBCXX_USE_CXX11_ABI = True
-
-     ext_modules.append(
-         CUDAExtension(
-             name="selective_scan_cuda",
-             sources=[
-                 "csrc/selective_scan/selective_scan.cpp",
-                 "csrc/selective_scan/selective_scan_fwd_fp32.cu",
-                 "csrc/selective_scan/selective_scan_fwd_fp16.cu",
-                 "csrc/selective_scan/selective_scan_fwd_bf16.cu",
-                 "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
-                 "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
-                 "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
-                 "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
-                 "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
-                 "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
-             ],
-             extra_compile_args={
-                 "cxx": ["-O3", "-std=c++17"],
-                 "nvcc": append_nvcc_threads(
-                     [
-                         "-O3",
-                         "-std=c++17",
-                         "-U__CUDA_NO_HALF_OPERATORS__",
-                         "-U__CUDA_NO_HALF_CONVERSIONS__",
-                         "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-                         "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                         "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-                         "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-                         "--expt-relaxed-constexpr",
-                         "--expt-extended-lambda",
-                         "--use_fast_math",
-                         "--ptxas-options=-v",
-                         "-lineinfo",
-                     ]
-                     + cc_flag
-                 ),
-             },
-             include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
-         )
-     )
-
-
- setup(
-     name=PACKAGE_NAME,
-     version=VERSION,
-     description="Blackmamba state-space + MoE model",
-     long_description=long_description,
-     long_description_content_type="text/markdown",
-     packages=find_packages(include=['ops'],),
-     exclude=(
-         "csrc",
-         "blackmamba.egg-info",
-     ),
-     ext_modules=ext_modules,
-     cmdclass={"build_ext": BuildExtension},
-     python_requires=">=3.7",
-     install_requires=[
-         "torch",
-         "packaging",
-         "ninja",
-         "einops",
-         "triton",
-         "transformers",
-         "causal_conv1d>=1.1.0",
-     ],
- )
 
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+ # Load the pre-trained WormGPT model and tokenizer
+ model = GPT2LMHeadModel.from_pretrained("wormgpt")
+ tokenizer = GPT2Tokenizer.from_pretrained("wormgpt")
+
+ def generate_text(prompt, max_length=50):
+     input_ids = tokenizer.encode(prompt, return_tensors="pt")
+     output = model.generate(input_ids, max_length=max_length, num_return_sequences=1)
+     generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     return generated_text
+
+ # Example usage
+ prompt = "Generate malicious code for a virus."
+ malicious_code = generate_text(prompt)
+ print(malicious_code)