Spaces:
Running
Running
File size: 2,283 Bytes
8d780f0 e164375 8d780f0 e164375 8d780f0 7aff8d2 8d780f0 e164375 8d780f0 e164375 8d780f0 e164375 8d780f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
import torch
from transformers import AutoTokenizer
from model import SmollM
import yaml
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model and tokenizer settings are read from config.yaml in the working directory.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

## Speed up float32 matmuls by allowing reduced internal precision
torch.set_float32_matmul_precision('high')

# Build the model architecture from the config.
model = SmollM(config['model']['model_config'])

# Load the tokenizer named in the config.
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name_or_path'])

# Load the trained weights. The checkpoint is passed straight to
# load_state_dict, so it is expected to be a plain state_dict. It is
# deserialized onto the CPU first so it also loads on machines without a GPU.
model_path = "model.pt"  # Replace with the path to your model weights
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint)
# generate_tokens() places the prompt ids on `device`, so the weights must
# live there too; otherwise every request fails with a device mismatch on
# CUDA hosts.
model.to(device)
model.eval()  # Set the model to evaluation mode
def generate_tokens(model, tokenizer, prompt, max_length=50, device="cuda", context_length=None):
    """Greedily generate a continuation of ``prompt`` and return the decoded text.

    Args:
        model: Causal language model called as ``model(input_ids)``. The result
            is indexed as ``logits[:, -1, :]``, so it is assumed to return
            logits shaped (batch, seq, vocab) (TODO confirm against SmollM).
        tokenizer: Hugging Face-style tokenizer used to encode ``prompt`` and
            decode the result.
        prompt: Input text to continue.
        max_length: Maximum number of new tokens to generate. It is cast to
            ``int`` because UI sliders may deliver floats.
        device: Device the prompt ids are moved to. It must match the model's
            device.
        context_length: Optional cap on how many of the most recent tokens are
            fed to the model at each step (e.g. the model's maximum sequence
            length). ``None`` feeds the whole sequence.

    Returns:
        The prompt plus the generated tokens, decoded with special tokens
        skipped. Returns ``""`` if the prompt encodes to no tokens.
    """
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    if input_ids.numel() == 0:
        # Nothing to condition on; calling the model on an empty sequence fails.
        return ""
    outputs = input_ids
    with torch.no_grad():
        for _ in range(int(max_length)):
            # Feed the full sequence so far. The model receives only token ids
            # (no past-key-values argument), so passing just the last token
            # would discard the prompt context. This assumes SmollM keeps no
            # internal KV cache.
            context = outputs if context_length is None else outputs[:, -context_length:]
            logits = model(context)
            # Greedy decoding: take the most likely next token.
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_token], dim=1)
            # .item() assumes a batch of one, which is all this app produces.
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio callback bound to the module-level model, tokenizer and device.
def generate_text(prompt, max_length=50):
    """Continue ``prompt`` with the loaded model, producing up to ``max_length`` new tokens."""
    return generate_tokens(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        max_length=max_length,
        device=device,
    )
# Web UI: a prompt box and a length slider feed generate_text, and its
# result is shown in the output box when the button is pressed.
with gr.Blocks() as demo:
    gr.Markdown(value="# SmoLLM-135M Text Generation Demo")
    gr.Markdown(value="Provide an input text prompt, and the model will generate text based on it.")
    with gr.Row():
        input_text = gr.Textbox(
            lines=2,
            label="Input Prompt",
            placeholder="Enter your text here...",
        )
        max_len = gr.Slider(
            minimum=10,
            maximum=100,
            step=5,
            value=50,
            label="Max Output Length",
        )
    output_text = gr.Textbox(lines=5, label="Generated Text")
    generate_button = gr.Button(value="Generate")
    generate_button.click(
        fn=generate_text,
        inputs=[input_text, max_len],
        outputs=output_text,
    )

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()
|