import torch

from mamba_block import MambaBlock
from mamba_config import MambaConfig
from mamba_layer import MambaLayer
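
# Hyperparameters for a small Mamba-style model. The field names follow the
# local mamba_config module; the three False flags at the end keep the blocks
# on the simplest code path (plain LayerNorm rather than RMSNorm, residuals in
# model dtype, no fused add+norm kernel).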
config = MambaConfig(
    hidden_size=512,
    num_layers=6,
    num_heads=8,
    intermediate_size=2048,
    max_position_embeddings=1024,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
)


class MambaModel(torch.nn.Module):
    """A stack of Mamba blocks followed by a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # One residual block per layer; each MambaBlock wraps a MambaLayer.
        self.layers = torch.nn.ModuleList(
            [MambaBlock(config, MambaLayer) for _ in range(config.num_layers)]
        )
        self.norm = torch.nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        # Fold the last residual back in before the final norm; the added
        # parentheses make the conditional's precedence explicit.
        hidden_states = self.norm(
            (hidden_states + residual) if residual is not None else hidden_states
        )
        return hidden_states
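
# Shape convention: inputs and outputs are (batch, seq_len, hidden_size).
# This assumes each MambaBlock returns an updated (hidden_states, residual)
# pair, as in the reference Mamba block interface.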

mamba_model = MambaModel(config)
mamba_model.eval()  # inference mode: puts dropout-style layers in eval behavior


def generate_text(prompt, model, max_length=50):
    # Stand-in for a real pipeline: rather than tokenizing the prompt and
    # looking up learned embeddings, this uses random vectors of the right
    # shape (batch=1, seq_len=len(prompt), hidden_size).
    hidden_states = torch.randn(1, len(prompt), config.hidden_size)

    with torch.no_grad():
        output = model(hidden_states)

    # Placeholder: a real implementation would project `output` to vocabulary
    # logits and decode up to max_length tokens; this stub returns fixed text.
    generated_text = "Here is the generated text"

    return generated_text
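

# A minimal sketch (not part of the original stub) of what a real decode loop
# could look like: hypothetical byte-level embedding and output-projection
# layers with greedy argmax decoding. The weights are untrained, so the output
# is gibberish, but the data flow is the point.
def generate_text_sketch(prompt, model, max_length=50, vocab_size=256):
    embed = torch.nn.Embedding(vocab_size, config.hidden_size)  # hypothetical
    lm_head = torch.nn.Linear(config.hidden_size, vocab_size)   # hypothetical
    ids = torch.tensor([list(prompt.encode("utf-8"))])  # (1, seq_len) byte ids
    for _ in range(max_length):
        with torch.no_grad():
            logits = lm_head(model(embed(ids)))  # (1, seq_len, vocab_size)
        next_id = logits[0, -1].argmax().view(1, 1)  # greedy next byte
        ids = torch.cat([ids, next_id], dim=1)
    return bytes(ids[0].tolist()).decode("utf-8", errors="replace")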


def generate_uncensored_text(prompt, max_length=50):
    # Thin wrapper that routes the prompt through the module-level model.
    mamba_text = generate_text(prompt, mamba_model, max_length)
    return mamba_text


prompt = "I want to generate some uncensored text."
uncensored_text = generate_uncensored_text(prompt)
print(uncensored_text)