# next-token / app.py
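"""Gradio Space that visualizes GPT-2 next-token probabilities.

Given an input prompt, the app shows the top-k most likely next tokens
both as a seaborn bar chart and as a JSON mapping of token -> probability.
"""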
import spaces
import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Load model and tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
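# @spaces.GPU requests a ZeroGPU slot for the duration of each call on Hugging Face Spaces.
# Note: the model is never moved to CUDA below, so inference actually runs on CPU; moving
# the model and input_ids to "cuda" inside the function would make use of the allocated GPU.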
@spaces.GPU
def get_token_probabilities(text, top_k=10):
    top_k = int(top_k)  # the Gradio slider may deliver a float

    # Tokenize the input text
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Get the last token's position
    last_token_position = input_ids.shape[1] - 1

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Get probabilities for the next token after the last token
    next_token_logits = logits[0, last_token_position, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Get top k most likely tokens
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    # Convert to numpy for easier handling
    topk_probs = topk_probs.numpy()
    topk_indices = topk_indices.numpy()

    # Decode tokens
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Create a plot
    plt.figure(figsize=(10, 6))
    sns.barplot(x=topk_probs, y=topk_tokens)
    plt.title(f"Top {top_k} token probabilities after: '{text}'")
    plt.xlabel("Probability")
    plt.ylabel("Tokens")
    plt.tight_layout()

    # Ensure temp directory exists
    os.makedirs("tmp", exist_ok=True)

    # Save the plot to a file in the temp directory
    plot_path = os.path.join("tmp", "token_probabilities.png")
    plt.savefig(plot_path)
    plt.close()

    return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
    gr.Markdown("Enter text and see the probabilities of possible next tokens.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type some text here...",
                value="Hello, my name is",
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top tokens to show",
            )
            btn = gr.Button("Generate Probabilities")

        with gr.Column():
            output_image = gr.Image(label="Probability Distribution")
            output_table = gr.JSON(label="Token Probabilities")

    btn.click(
        fn=get_token_probabilities,
        inputs=[input_text, top_k],
        outputs=[output_image, output_table],
    )

    gr.Examples(
        examples=[
            ["Hello, my name is", 10],
            ["The capital of France is", 10],
            ["Once upon a time", 10],
            ["The best way to learn is to", 10],
        ],
        inputs=[input_text, top_k],
    )
# Launch the app
demo.launch()
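# To run locally: pip install spaces gradio torch transformers matplotlib seaborn,
# then `python app.py`; outside a ZeroGPU Space the @spaces.GPU decorator should be a no-op.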