# next-token / app.py
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
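# Optional (my assumption, not in the original app): select a non-interactive
# matplotlib backend so figures render in headless environments such as a
# Hugging Face Space; ideally this happens before pyplot is imported below.
import matplotlib
matplotlib.use("Agg")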
import matplotlib.pyplot as plt
import seaborn as sns

# Load model and tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
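
# Note: from_pretrained() returns the model in eval mode, and the forward pass in
# get_token_probabilities() below runs under torch.no_grad(), so no gradients are
# computed during inference.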


def get_token_probabilities(text, top_k=10):
    # Ensure top_k is an int (Gradio sliders may deliver the value as a float)
    top_k = int(top_k)

    # Tokenize the input text
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Get the last token's position
    last_token_position = input_ids.shape[1] - 1

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Get probabilities for the next token after the last token
    next_token_logits = logits[0, last_token_position, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Get top k most likely tokens
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    # Convert to numpy for easier handling
    topk_probs = topk_probs.numpy()
    topk_indices = topk_indices.numpy()

    # Decode tokens
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Create a plot
    plt.figure(figsize=(10, 6))
    sns.barplot(x=topk_probs, y=topk_tokens)
    plt.title(f"Top {top_k} token probabilities after: '{text}'")
    plt.xlabel("Probability")
    plt.ylabel("Tokens")
    plt.tight_layout()

    # Save the plot to a file
    plt.savefig("token_probabilities.png")
    plt.close()

    return "token_probabilities.png", dict(zip(topk_tokens, topk_probs.tolist()))


def interface():
    with gr.Blocks() as demo:
        gr.Markdown("# GPT-2 Next Token Probability Visualizer")
        gr.Markdown("Enter text and see the probabilities of possible next tokens.")

        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input Text",
                    placeholder="Type some text here...",
                    value="Hello, my name is"
                )
                top_k = gr.Slider(
                    minimum=5,
                    maximum=20,
                    value=10,
                    step=1,
                    label="Number of top tokens to show"
                )
                btn = gr.Button("Generate Probabilities")

            with gr.Column():
                output_image = gr.Image(label="Probability Distribution")
                output_table = gr.JSON(label="Token Probabilities")

        btn.click(
            fn=get_token_probabilities,
            inputs=[input_text, top_k],
            outputs=[output_image, output_table]
        )

        gr.Examples(
            examples=[
                ["Hello, my name is", 10],
                ["The capital of France is", 10],
                ["Once upon a time", 10],
                ["The best way to learn is to", 10]
            ],
            inputs=[input_text, top_k],
        )

    return demo


if __name__ == "__main__":
    demo = interface()
    demo.launch()
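
# Notes (assumptions about the runtime environment, not stated in the app itself):
#   - Expected dependencies: gradio, torch, transformers, matplotlib, seaborn.
#   - On a Hugging Face Space, demo.launch() is sufficient; when running locally you
#     could pass demo.launch(share=True) to get a temporary public URL.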