yukimama commited on
Commit
877ed39
·
verified ·
1 Parent(s): 4264909
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from mamba_block import MambaBlock
3
+ from mamba_config import MambaConfig
4
+ from mamba_layer import MambaLayer
5
+
6
+ # 創建一個Mamba配置
7
+ config = MambaConfig(
8
+ hidden_size=512,
9
+ num_layers=6,
10
+ num_heads=8,
11
+ intermediate_size=2048,
12
+ max_position_embeddings=1024,
13
+ rms_norm=False,
14
+ residual_in_fp32=False,
15
+ fused_add_norm=False,
16
+ )
17
+
18
+ # 創建一個Mamba模型
19
+ class MambaModel(torch.nn.Module):
20
+ def __init__(self, config):
21
+ super().__init__()
22
+ self.config = config
23
+ self.layers = torch.nn.ModuleList([MambaBlock(config, MambaLayer) for _ in range(config.num_layers)])
24
+ self.norm = torch.nn.LayerNorm(config.hidden_size)
25
+
26
+ def forward(self, hidden_states: torch.Tensor):
27
+ residual = None
28
+ for layer in self.layers:
29
+ hidden_states, residual = layer(hidden_states, residual)
30
+ hidden_states = self.norm(hidden_states + residual if residual is not None else hidden_states)
31
+ return hidden_states
32
+
33
+ # 創建模型實例
34
+ mamba_model = MambaModel(config)
35
+ mamba_model.eval()
36
+
37
+ # Function to generate text from a given prompt using the Mamba model
38
+ def generate_text(prompt, model, max_length=50):
39
+ # 這裡假設你的prompt已經被轉換為嵌入向量
40
+ hidden_states = torch.randn(1, len(prompt), config.hidden_size) # 假設你的輸入序列長度是len(prompt)
41
+
42
+ with torch.no_grad():
43
+ output = model(hidden_states)
44
+
45
+ # 這裡你需要將模型輸出轉換為可讀的文本
46
+ # 這只是一個示例,實際上你可能需要一個解碼器來將輸出轉換為文本
47
+ generated_text = "這裡是生成的文本" # 這裡應該是你的實際生成的文本
48
+
49
+ return generated_text
50
+
51
+ # Function to generate text from a given prompt using the Mamba model
52
+ def generate_uncensored_text(prompt, max_length=50):
53
+ mamba_text = generate_text(prompt, mamba_model, max_length)
54
+ return mamba_text
55
+
56
+ # Example usage
57
+ prompt = "I want to generate some uncensored text."
58
+ uncensored_text = generate_uncensored_text(prompt)
59
+ print(uncensored_text)