File size: 2,283 Bytes
8d780f0
 
e164375
 
 
8d780f0
e164375
 
 
 
 
 
 
 
 
 
 
 
8d780f0
 
7aff8d2
 
 
8d780f0
 
e164375
 
 
 
 
 
 
 
 
 
 
 
 
8d780f0
 
 
e164375
8d780f0
 
 
 
e164375
8d780f0
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import torch
from transformers import AutoTokenizer
from model import SmollM
import yaml

device = "cuda" if torch.cuda.is_available() else "cpu"
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

## Speed up float32 matmuls: 'high' lets PyTorch use faster, slightly less
## precise algorithms (e.g. TF32) on hardware that supports them.
torch.set_float32_matmul_precision('high')

# Build the model architecture from the config.
model = SmollM(config['model']['model_config'])

# Load the tokenizer named in the config.
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name_or_path'])

# Load your custom model weights (adjust as necessary for your model's implementation).
# The checkpoint is mapped to the CPU first so it loads regardless of the device
# it was saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (consider weights_only=True on PyTorch versions that support it).
model_path = "model.pt"  # Replace with the path to your model weights
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
model.load_state_dict(checkpoint)
# Put the weights on the same device that generate_text() passes to
# generate_tokens() for the input ids; otherwise CUDA inputs would meet CPU
# weights and the forward pass would fail.
model.to(device)
model.eval()  # Set the model to evaluation mode

def generate_tokens(model, tokenizer, prompt, max_length=50, device="cuda"):
    """Greedily generate a continuation of ``prompt`` and return it as text.

    Args:
        model: Causal LM called as ``model(input_ids)``. Assumed to return
            logits shaped ``(batch, seq_len, vocab)``; the last position is
            used as the next-token distribution (TODO confirm against model.py).
        tokenizer: Hugging Face-style tokenizer providing ``__call__``,
            ``decode`` and ``eos_token_id``.
        prompt: Input text to continue.
        max_length: Maximum number of new tokens to generate.
        device: Device the prompt ids are moved to; must match the model's.

    Returns:
        The decoded prompt plus generated tokens, with special tokens skipped.
        Generation stops early once the tokenizer's EOS token is produced.
    """
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    eos_id = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = input_ids
        for _ in range(max_length):
            # No KV cache is passed to the model, so it must see the whole
            # sequence every step; feeding only the last token would discard
            # the prompt and all previously generated context.
            logits = model(outputs)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_token], dim=1)
            # A single prompt string tokenizes to batch size 1, so .item() is safe.
            if next_token.item() == eos_id:
                break
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Gradio callback: generate text for the UI's prompt and length inputs.
def generate_text(prompt, max_length=50):
    """Run greedy generation with the module-level model, tokenizer and device.

    Args:
        prompt: Text entered by the user.
        max_length: Maximum number of new tokens. Coerced to ``int`` because
            a Gradio slider may deliver its value as a float, and the
            ``range()`` in the generation loop rejects floats.

    Returns:
        The decoded prompt plus its generated continuation.
    """
    return generate_tokens(model, tokenizer, prompt, int(max_length), device)


# Gradio interface: a prompt box and a length slider side by side, an output
# box, and a button that runs generate_text on the two inputs.
with gr.Blocks() as demo:
    gr.Markdown("# SmoLLM-135M Text Generation Demo")
    gr.Markdown("Provide an input text prompt, and the model will generate text based on it.")

    # Inputs are laid out in a single row.
    with gr.Row():
        input_text = gr.Textbox(label="Input Prompt", placeholder="Enter your text here...", lines=2)
        # Passed to generate_text as max_length (10-100 in steps of 5).
        max_len = gr.Slider(label="Max Output Length", minimum=10, maximum=100, value=50, step=5)

    output_text = gr.Textbox(label="Generated Text", lines=5)
    generate_button = gr.Button("Generate")

    # On click, call generate_text(prompt, max_len) and show the result in output_text.
    generate_button.click(generate_text, inputs=[input_text, max_len], outputs=output_text)

# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()