next-token / app.py
davanstrien's picture
davanstrien HF Staff
Update app.py
b24cb59 verified
raw
history blame
2.43 kB
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Model setup: GPT-2 small and its matching tokenizer (downloaded/cached on first run).
MODEL_NAME = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
def get_next_token_probs(text, top_k=5):
    """Return the most likely next tokens for *text* as formatted strings.

    Args:
        text: Prompt to continue. Blank / whitespace-only input short-circuits
            without running the model.
        top_k: Number of candidate tokens to return. Defaults to 5 to match
            the five output slots in the UI.

    Returns:
        A list of ``top_k`` strings like ``1. "␣token" (12.3%)``, or
        placeholder strings when the input is empty.
    """
    # Handle empty input without touching the model.
    if not text.strip():
        return ["No input text"] * top_k

    # Tokenize input for GPT-2.
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Distribution over the vocabulary at the *last* position (the next token).
    next_token_logits = logits[0, -1, :]
    next_token_probs = torch.softmax(next_token_logits, dim=-1)

    # Top-k candidate tokens and their probabilities.
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Format the results as ranked, human-readable strings.
    formatted_results = []
    for i, (token, prob) in enumerate(zip(topk_tokens, topk_probs)):
        # Probability as a percentage with 1 decimal place.
        prob_percent = f"{prob.item()*100:.1f}%"
        # Make leading spaces visible in the Markdown output.
        display_token = token.replace(" ", "␣")
        formatted_results.append(f"{i+1}. \"{display_token}\" ({prob_percent})")
    return formatted_results
# --- Minimal live-updating interface ------------------------------------
with gr.Blocks(css="footer {display: none}") as demo:
    gr.Markdown("### GPT-2 Next Token Predictor")

    # Prompt the model continues; predictions refresh on every edit.
    input_text = gr.Textbox(
        label="Text Input",
        placeholder="Type here and watch predictions update...",
        value="The weather tomorrow will be",
    )

    # Header for the results section.
    gr.Markdown("##### Most likely next tokens:")

    # One Markdown slot per predicted token (five in total).
    token_outputs = [gr.Markdown() for _ in range(5)]

    # Recompute predictions whenever the text changes ...
    input_text.change(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

    # ... and once on page load, for the default prompt.
    demo.load(
        fn=get_next_token_probs,
        inputs=input_text,
        outputs=token_outputs,
    )

# Launch the app
demo.launch()