# Hugging Face Spaces — status: Running
# NOTE: `spaces` (Hugging Face ZeroGPU helper) is kept first; it is generally
# expected to be imported before torch initializes CUDA — confirm before moving.
import spaces

import os
from collections import Counter

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from matplotlib.figure import Figure
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# GPT-2 checkpoint shared by the tokenizer and the language-model head.
model_name = "gpt2"
# Both are loaded once at import time and reused by every request.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
def _escape_mathtext(label):
    """Return *label* with every ``$`` escaped so Matplotlib draws it literally."""
    # Matplotlib parses a string containing an even number of unescaped "$"
    # as mathtext, so user text like "$5 or $10" would be garbled or fail to
    # parse. An escaped "\$" is rendered as a plain "$".
    return label.replace("$", r"\$")


def _disambiguate_tokens(tokens, token_ids):
    """Return display labels, appending the token id to any repeated string.

    Distinct GPT-2 token ids can decode to identical text (e.g. partial UTF-8
    byte tokens all decode to the replacement character), which would merge
    bars in the plot and overwrite keys in the returned dict.
    """
    counts = Counter(tokens)
    return [
        f"{token} [id {token_id}]" if counts[token] > 1 else token
        for token, token_id in zip(tokens, token_ids)
    ]


def _save_probability_plot(tokens, probs, top_k, text):
    """Draw a horizontal bar chart of *probs* per token and return the PNG path."""
    # A standalone Figure (not pyplot) has no global figure registry, so
    # nothing is leaked if plotting raises and calls do not share state.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    sns.barplot(x=probs, y=[_escape_mathtext(t) for t in tokens], ax=ax)
    ax.set_title(
        _escape_mathtext(f"Top {top_k} token probabilities after: '{text}'")
    )
    ax.set_xlabel("Probability")
    ax.set_ylabel("Tokens")
    fig.tight_layout()

    os.makedirs("tmp", exist_ok=True)
    plot_path = os.path.join("tmp", "token_probabilities.png")
    fig.savefig(plot_path)
    return plot_path


def get_token_probabilities(text, top_k=10):
    """Plot and return GPT-2's most likely next tokens after *text*.

    Args:
        text: Prompt string. An empty prompt is conditioned on the tokenizer's
            BOS token instead of crashing on a zero-length input; prompts longer
            than the model's context window keep only their most recent tokens.
        top_k: Number of candidates to show. UI sliders may deliver a float, so
            it is cast to ``int`` and clamped to ``[1, vocabulary size]``.

    Returns:
        ``(plot_path, probabilities)``. ``plot_path`` is the saved PNG bar
        chart. ``probabilities`` maps each decoded token to its probability;
        the token id is appended when several ids decode to the same text.
    """
    text = text or ""
    input_ids = tokenizer.encode(text, return_tensors="pt")
    if input_ids.shape[1] == 0:
        # encode("") yields a (1, 0) tensor, which the model cannot run on.
        input_ids = torch.tensor([[tokenizer.bos_token_id]])
    # GPT-2 has a fixed number of positions; keep the most recent tokens,
    # since they are the ones that condition the next-token prediction.
    input_ids = input_ids[:, -model.config.n_positions:]

    with torch.no_grad():
        logits = model(input_ids).logits
    # The logits at the last position score the token that would come next.
    next_token_probs = torch.softmax(logits[0, -1, :], dim=-1)

    k = max(1, min(int(top_k), next_token_probs.numel()))
    topk_probs, topk_indices = torch.topk(next_token_probs, k)
    probs = topk_probs.tolist()
    token_ids = topk_indices.tolist()

    tokens = _disambiguate_tokens(
        [tokenizer.decode([token_id]) for token_id in token_ids], token_ids
    )
    plot_path = _save_probability_plot(tokens, probs, k, text)
    return plot_path, dict(zip(tokens, probs))
# Two-column layout: prompt controls on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
    gr.Markdown("Enter text and see the probabilities of possible next tokens.")

    with gr.Row():
        # Left column: the prompt, how many candidates to show, and the trigger.
        with gr.Column():
            input_text = gr.Textbox(
                value="Hello, my name is",
                placeholder="Type some text here...",
                label="Input Text",
            )
            top_k = gr.Slider(
                label="Number of top tokens to show",
                minimum=5,
                maximum=20,
                step=1,
                value=10,
            )
            btn = gr.Button("Generate Probabilities")
        # Right column: the rendered bar chart and the token -> probability map.
        with gr.Column():
            output_image = gr.Image(label="Probability Distribution")
            output_table = gr.JSON(label="Token Probabilities")

    # The handler returns (plot_path, dict), matching the two outputs in order.
    btn.click(
        get_token_probabilities,
        inputs=[input_text, top_k],
        outputs=[output_image, output_table],
    )

    # No fn/outputs are given here, so examples presumably only populate the
    # input controls; the user still clicks the button to run the model.
    gr.Examples(
        examples=[
            [prompt, 10]
            for prompt in (
                "Hello, my name is",
                "The capital of France is",
                "Once upon a time",
                "The best way to learn is to",
            )
        ],
        inputs=[input_text, top_k],
    )

# Start serving the interface.
demo.launch()