import torch

from mamba_block import MambaBlock
from mamba_config import MambaConfig
from mamba_layer import MambaLayer

config = MambaConfig(
    hidden_size=512,
    num_layers=6,
    num_heads=8,
    intermediate_size=2048,
    max_position_embeddings=1024,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
)
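
# A quick sanity check (a sketch, assuming these fields interact as in standard
# transformer-style configs): the hidden size should divide evenly across heads.
assert config.hidden_size % config.num_heads == 0, "hidden_size must be divisible by num_heads"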

class MambaModel(torch.nn.Module):
    """A stack of Mamba blocks followed by a final layer norm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Each block carries a separate residual stream alongside the hidden states.
        self.layers = torch.nn.ModuleList(
            [MambaBlock(config, MambaLayer) for _ in range(config.num_layers)]
        )
        self.norm = torch.nn.LayerNorm(config.hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        # Fold the residual stream back in before the final norm.
        hidden_states = self.norm(
            (hidden_states + residual) if residual is not None else hidden_states
        )
        return hidden_states

mamba_model = MambaModel(config)
mamba_model.eval()  # inference mode: disables dropout and similar training-only behavior
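
# Minimal shape check: the model only applies residual blocks and a final norm,
# so a forward pass should preserve (batch, seq_len, hidden_size). This assumes
# MambaBlock accepts the (config, layer_cls) and (hidden_states, residual)
# signatures used above.
with torch.no_grad():
    _probe = torch.randn(1, 16, config.hidden_size)
    assert mamba_model(_probe).shape == _probe.shape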

def generate_text(prompt, model, max_length=50):
    # Placeholder: without a tokenizer or embedding layer, stand in for the
    # embedded prompt with random hidden states of shape (batch, seq_len, hidden).
    hidden_states = torch.randn(1, len(prompt), config.hidden_size)

    with torch.no_grad():
        output = model(hidden_states)  # unused for now: there is no LM head to decode it

    # Placeholder output: no language-model head or decoding loop is wired up,
    # so the model's output is not actually turned into text.
    generated_text = "Here is the generated text"
    return generated_text
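
# For reference, a real decoding loop would look roughly like the sketch below.
# It is left commented out because `tokenizer`, `embedding`, and `lm_head` are
# hypothetical components this file does not define.
#
# def generate_text_real(prompt, model, max_length=50):
#     ids = tokenizer.encode(prompt)                  # hypothetical tokenizer
#     for _ in range(max_length):
#         hidden = embedding(torch.tensor([ids]))     # hypothetical embedding layer
#         logits = lm_head(model(hidden))             # hypothetical LM head
#         ids.append(int(logits[0, -1].argmax()))     # greedy next-token choice
#     return tokenizer.decode(ids)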

def generate_uncensored_text(prompt, max_length=50):
    # Thin wrapper that routes the prompt through the Mamba model above.
    return generate_text(prompt, mamba_model, max_length)

prompt = "I want to generate some uncensored text."
uncensored_text = generate_uncensored_text(prompt)
print(uncensored_text)