import gradio as gr
import torch
from transformers import AutoTokenizer
from model import SmollM
import yaml

device = "cuda" if torch.cuda.is_available() else "cpu"

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Speed up float32 matmuls
torch.set_float32_matmul_precision('high')
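# ('high' lets float32 matmuls use TF32 on supported NVIDIA GPUs: faster, slightly less precise.)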

# Build the model from the config
model = SmollM(config['model']['model_config'])

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name_or_path'])

# Load your custom model weights (adjust as necessary for your model's implementation)
model_path = "model.pt"  # Replace with the path to your model weights
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint)
model.to(device)  # Move the model to the same device the inputs will be placed on
model.eval()  # Set the model to evaluation mode
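# Assumption: model.pt stores a plain state_dict. If it was saved as a full training
# checkpoint (e.g. {"model_state_dict": ...}), load that key instead.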

def generate_tokens(model, tokenizer, prompt, max_length=50, device="cuda"):
    """Greedily generates up to max_length tokens from the given prompt."""
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        outputs = input_ids
        for _ in range(max_length):
            # Re-run the model on the full sequence each step (assumes SmollM has no
            # KV cache and returns logits of shape [batch, seq_len, vocab_size]).
            logits = model(outputs)
            # Greedy decoding: pick the most likely next token
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_token], dim=1)
            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
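
# Example (hypothetical prompt):
#   generate_tokens(model, tokenizer, "Once upon a time", max_length=30, device=device)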

# Gradio callback: tokenize the input prompt and generate text
def generate_text(prompt, max_length=50):
    max_length = int(max_length)  # Gradio sliders may return floats
    return generate_tokens(model, tokenizer, prompt, max_length, device)

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# SmolLM-135M Text Generation Demo")
    gr.Markdown("Provide an input text prompt, and the model will generate text based on it.")
    with gr.Row():
        input_text = gr.Textbox(label="Input Prompt", placeholder="Enter your text here...", lines=2)
        max_len = gr.Slider(label="Max Output Length", minimum=10, maximum=100, value=50, step=5)
    output_text = gr.Textbox(label="Generated Text", lines=5)
    generate_button = gr.Button("Generate")
    generate_button.click(generate_text, inputs=[input_text, max_len], outputs=output_text)
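    # Clicking the button passes the prompt and slider value to generate_text
    # and displays the returned string in output_text.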

# Run the app
if __name__ == "__main__":
    demo.launch()